From 64e92001db70ff6371859edb893dc5db8a25645f Mon Sep 17 00:00:00 2001
From: Jeroen Doornbos
Date: Sun, 29 Sep 2024 15:33:35 +0200
Subject: [PATCH 01/12] added github workflows

---
 .github/workflows/black.yml  | 10 ++++++++++
 .github/workflows/pytest.yml | 24 ++++++++++++++++++++++++
 README.md                    |  5 +++++
 docs/source/readme.rst       |  6 ++++++
 4 files changed, 45 insertions(+)
 create mode 100644 .github/workflows/black.yml
 create mode 100644 .github/workflows/pytest.yml

diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml
new file mode 100644
index 0000000..68b72e8
--- /dev/null
+++ b/.github/workflows/black.yml
@@ -0,0 +1,10 @@
+name: Black
+
+on: [workflow_call, push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: psf/black@stable
\ No newline at end of file
diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
new file mode 100644
index 0000000..d0b9728
--- /dev/null
+++ b/.github/workflows/pytest.yml
@@ -0,0 +1,24 @@
+name: PyTest
+
+on: [workflow_call, push, pull_request]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11"]
+
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[dev]
+      - name: Test with pytest
+        run: |
+          pytest
\ No newline at end of file
diff --git a/README.md b/README.md
index 2051751..8409041 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,11 @@
 
 
 
+[![PyTest](https://github.com/IvoVellekoop/openwfs/actions/workflows/pytest.yml/badge.svg)](https://github.com/IvoVellekoop/openwfs/actions/workflows/pytest.yml)
+[![Black](https://github.com/IvoVellekoop/openwfs/actions/workflows/black.yml/badge.svg)](https://github.com/IvoVellekoop/openwfs/actions/workflows/black.yml)
+
+What is wavefront shaping?
+
 ## What is wavefront shaping?
 
 Wavefront shaping (WFS) is a technique for controlling the propagation of light in arbitrarily complex structures, including strongly scattering materials [[1](#id62)]. In WFS, a spatial light modulator (SLM) is used to shape the phase and/or amplitude of the incident light. With a properly constructed wavefront, light can be made to focus through [[2](#id48)], or inside [[3](#id37)] scattering materials; or light can be shaped to have other desired properties, such as optimal sensitivity for specific measurements [[4](#id38)], specialized point-spread functions [[5](#id25)] or for functions like optical trapping [[6](#id28)].
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index ebd0fad..87c7d8e 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -12,6 +12,12 @@ OpenWFS
     :target: https://openwfs.readthedocs.io/en/latest/?badge=latest
    :alt: Documentation Status
 
+.. 
only:: markdown + + [![PyTest](https://github.com/IvoVellekoop/openwfs/actions/workflows/pytest.yml/badge.svg)](https://github.com/IvoVellekoop/openwfs/actions/workflows/pytest.yml) + [![Black](https://github.com/IvoVellekoop/openwfs/actions/workflows/black.yml/badge.svg)](https://github.com/IvoVellekoop/openwfs/actions/workflows/black.yml) + +What is wavefront shaping? What is wavefront shaping? -------------------------------- From 8e0e00aee5ce9ca26f8e0e5cfb4e243383fcba94 Mon Sep 17 00:00:00 2001 From: Jeroen Doornbos Date: Sun, 29 Sep 2024 15:40:47 +0200 Subject: [PATCH 02/12] run black --- docs/source/conf.py | 128 ++++--- examples/hello_simulation.py | 1 + examples/hello_wfs.py | 3 +- examples/mm_scanning_microscope.py | 48 ++- examples/sample_microscope.py | 33 +- examples/slm_demo.py | 4 +- examples/slm_disk.py | 2 +- examples/troubleshooter_demo.py | 15 +- examples/wfs_demonstration_experimental.py | 12 +- openwfs/algorithms/basic_fourier.py | 77 ++-- .../algorithms/custom_iter_dual_reference.py | 118 +++++-- openwfs/algorithms/fourier.py | 114 +++--- openwfs/algorithms/ssa.py | 15 +- openwfs/algorithms/troubleshoot.py | 331 +++++++++++------- openwfs/algorithms/utilities.py | 123 ++++--- openwfs/core.py | 132 +++++-- openwfs/devices/camera.py | 73 ++-- openwfs/devices/galvo_scanner.py | 284 ++++++++++----- openwfs/devices/nidaq_gain.py | 5 +- openwfs/devices/slm/context.py | 3 +- openwfs/devices/slm/geometry.py | 37 +- openwfs/devices/slm/patch.py | 113 ++++-- openwfs/devices/slm/slm.py | 250 +++++++++---- openwfs/devices/slm/texture.py | 88 ++++- openwfs/plot_utilities.py | 8 +- openwfs/processors/__init__.py | 9 +- openwfs/processors/processors.py | 135 ++++--- openwfs/simulation/__init__.py | 9 +- openwfs/simulation/microscope.py | 105 ++++-- openwfs/simulation/mockdevices.py | 90 +++-- openwfs/simulation/slm.py | 104 ++++-- openwfs/simulation/transmission.py | 18 +- openwfs/utilities/__init__.py | 14 +- openwfs/utilities/patterns.py | 48 ++- openwfs/utilities/utilities.py | 179 +++++++--- tests/test_algorithms_troubleshoot.py | 151 +++++--- tests/test_camera.py | 18 +- tests/test_core.py | 44 ++- tests/test_processors.py | 54 +-- tests/test_scanning_microscope.py | 118 +++++-- tests/test_simulation.py | 120 +++++-- tests/test_slm.py | 46 ++- tests/test_utilities.py | 65 +++- tests/test_wfs.py | 302 ++++++++++++---- 44 files changed, 2581 insertions(+), 1065 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index ba1c421..060e523 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,24 +19,36 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ['sphinx.ext.napoleon', 'sphinx.ext.autodoc', 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', 'sphinx_autodoc_typehints', 'sphinxcontrib.bibtex', 'sphinx.ext.autosectionlabel', - 'sphinx_markdown_builder', 'sphinx_gallery.gen_gallery'] +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", + "sphinxcontrib.bibtex", + "sphinx.ext.autosectionlabel", + "sphinx_markdown_builder", + "sphinx_gallery.gen_gallery", +] # basic project information -project = 'OpenWFS' -copyright = '2023-, Ivo Vellekoop, Daniël W. S. Cox, and Jeroen H. Doornbos, University of Twente' -author = 'Jeroen H. Doornbos, Daniël W. S. Cox, Tom Knop, Harish Sasikumar, Ivo M. 
Vellekoop' -release = '0.1.0rc2' -html_title = "OpenWFS - a library for conducting and simulating wavefront shaping experiments" +project = "OpenWFS" +copyright = "2023-, Ivo Vellekoop, Daniël W. S. Cox, and Jeroen H. Doornbos, University of Twente" +author = ( + "Jeroen H. Doornbos, Daniël W. S. Cox, Tom Knop, Harish Sasikumar, Ivo M. Vellekoop" +) +release = "0.1.0rc2" +html_title = ( + "OpenWFS - a library for conducting and simulating wavefront shaping experiments" +) # \renewenvironment{sphinxtheindex}{\setbox0\vbox\bgroup\begin{theindex}}{\end{theindex}} # latex configuration latex_elements = { - 'preamble': r""" + "preamble": r""" \usepackage{authblk} """, - 'maketitle': r""" + "maketitle": r""" \author[1]{Daniël~W.~S.~Cox} \author[1]{Tom~Knop} \author[1,2]{Harish~Sasikumar} @@ -67,40 +79,52 @@ } \maketitle """, - 'tableofcontents': "", - 'makeindex': "", - 'printindex': "", - 'figure_align': "", - 'extraclassoptions': 'notitlepage', + "tableofcontents": "", + "makeindex": "", + "printindex": "", + "figure_align": "", + "extraclassoptions": "notitlepage", } latex_docclass = { - 'manual': 'scrartcl', - 'howto': 'scrartcl', + "manual": "scrartcl", + "howto": "scrartcl", } -latex_documents = [('index_latex', 'OpenWFS.tex', - 'OpenWFS - a library for conducting and simulating wavefront shaping experiments', - 'Jeroen H. Doornbos', 'howto')] -latex_toplevel_sectioning = 'section' -bibtex_default_style = 'unsrt' -bibtex_bibfiles = ['references.bib'] +latex_documents = [ + ( + "index_latex", + "OpenWFS.tex", + "OpenWFS - a library for conducting and simulating wavefront shaping experiments", + "Jeroen H. Doornbos", + "howto", + ) +] +latex_toplevel_sectioning = "section" +bibtex_default_style = "unsrt" +bibtex_bibfiles = ["references.bib"] numfig = True -templates_path = ['_templates'] -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'acknowledgements.rst', 'sg_execution_times.rst'] -master_doc = '' -include_patterns = ['**'] +templates_path = ["_templates"] +exclude_patterns = [ + "_build", + "Thumbs.db", + ".DS_Store", + "acknowledgements.rst", + "sg_execution_times.rst", +] +master_doc = "" +include_patterns = ["**"] napoleon_use_rtype = False napoleon_use_param = True typehints_document_rtype = False -latex_engine = 'xelatex' -html_theme = 'sphinx_rtd_theme' +latex_engine = "xelatex" +html_theme = "sphinx_rtd_theme" add_module_names = False autodoc_preserve_defaults = True sphinx_gallery_conf = { - 'examples_dirs': '../../examples', # path to your example scripts - 'ignore_pattern': 'set_path.py', - 'gallery_dirs': 'auto_examples', # path to where to save gallery generated output + "examples_dirs": "../../examples", # path to your example scripts + "ignore_pattern": "set_path.py", + "gallery_dirs": "auto_examples", # path to where to save gallery generated output } # importing this module without OpenGL installed will fail, @@ -117,7 +141,7 @@ def skip(app, what, name, obj, skip, options): def visit_citation(self, node): """Patch-in function for markdown builder to support citations.""" - id = node['ids'][0] + id = node["ids"][0] self.add(f'') @@ -141,29 +165,35 @@ def setup(app): def source_read(app, docname, source): - if docname == 'readme' or docname == 'conclusion': - if (app.builder.name == 'latex') == (docname == 'conclusion'): - source[0] = source[0].replace('%endmatter%', '.. 
include:: acknowledgements.rst') + if docname == "readme" or docname == "conclusion": + if (app.builder.name == "latex") == (docname == "conclusion"): + source[0] = source[0].replace( + "%endmatter%", ".. include:: acknowledgements.rst" + ) else: - source[0] = source[0].replace('%endmatter%', '') + source[0] = source[0].replace("%endmatter%", "") def builder_inited(app): - if app.builder.name == 'html': - exclude_patterns.extend(['conclusion.rst', 'index_latex.rst', 'index_markdown.rst']) - app.config.master_doc = 'index' - elif app.builder.name == 'latex': - exclude_patterns.extend(['auto_examples/*', 'index_markdown.rst', 'index.rst', 'api*']) - app.config.master_doc = 'index_latex' - elif app.builder.name == 'markdown': + if app.builder.name == "html": + exclude_patterns.extend( + ["conclusion.rst", "index_latex.rst", "index_markdown.rst"] + ) + app.config.master_doc = "index" + elif app.builder.name == "latex": + exclude_patterns.extend( + ["auto_examples/*", "index_markdown.rst", "index.rst", "api*"] + ) + app.config.master_doc = "index_latex" + elif app.builder.name == "markdown": include_patterns.clear() - include_patterns.extend(['readme.rst', 'index_markdown.rst']) - app.config.master_doc = 'index_markdown' + include_patterns.extend(["readme.rst", "index_markdown.rst"]) + app.config.master_doc = "index_markdown" def copy_readme(app, exception): """Copy the readme file to the root of the documentation directory.""" - if exception is None and app.builder.name == 'markdown': - source_file = Path(app.outdir) / 'readme.md' - destination_dir = Path(app.confdir).parents[1] / 'README.md' + if exception is None and app.builder.name == "markdown": + source_file = Path(app.outdir) / "readme.md" + destination_dir = Path(app.confdir).parents[1] / "README.md" shutil.copy(source_file, destination_dir) diff --git a/examples/hello_simulation.py b/examples/hello_simulation.py index 57926a7..722e097 100644 --- a/examples/hello_simulation.py +++ b/examples/hello_simulation.py @@ -2,6 +2,7 @@ =============================================== Simulates a wavefront shaping experiment using a SimulatedWFS object, which acts both as a spatial light modulator (SLM) and a detector.""" + import numpy as np from openwfs.algorithms import StepwiseSequential diff --git a/examples/hello_wfs.py b/examples/hello_wfs.py index 28912d5..be7015f 100644 --- a/examples/hello_wfs.py +++ b/examples/hello_wfs.py @@ -7,6 +7,7 @@ and a spatial light modulator (SLM) connected to the secondary video output. 
""" + import numpy as np from openwfs.algorithms import StepwiseSequential @@ -18,7 +19,7 @@ # Connect to a GenICam camera, average pixels to get feedback signal camera = Camera(R"C:\Program Files\Basler\pylon 7\Runtime\x64\ProducerU3V.cti") -feedback = SingleRoi(camera, pos=(320, 320), mask_type='disk', radius=2.5) +feedback = SingleRoi(camera, pos=(320, 320), mask_type="disk", radius=2.5) # Run the algorithm alg = StepwiseSequential(feedback=feedback, slm=slm, n_x=10, n_y=10, phase_steps=4) diff --git a/examples/mm_scanning_microscope.py b/examples/mm_scanning_microscope.py index ec32c9d..3bb96d1 100644 --- a/examples/mm_scanning_microscope.py +++ b/examples/mm_scanning_microscope.py @@ -20,31 +20,51 @@ optical_deflection=1.0 / (0.22 * u.V / u.deg), galvo_to_pupil_magnification=2, objective_magnification=16, - reference_tube_lens=200 * u.mm) + reference_tube_lens=200 * u.mm, +) acceleration = Axis.compute_acceleration( optical_deflection=1.0 / (0.22 * u.V / u.deg), - torque_constant=2.8E5 * u.dyne * u.cm / u.A, - rotor_inertia=8.25 * u.g * u.cm ** 2, - maximum_current=4 * u.A) + torque_constant=2.8e5 * u.dyne * u.cm / u.A, + rotor_inertia=8.25 * u.g * u.cm**2, + maximum_current=4 * u.A, +) # scale = 440 * u.um / u.V (calibrated) sample_rate = 0.5 * u.MHz reference_zoom = 1.2 -y_axis = Axis(channel='Dev4/ao0', v_min=-2.0 * u.V, v_max=2.0 * u.V, maximum_acceleration=acceleration, scale=scale) -x_axis = Axis(channel='Dev4/ao1', v_min=-2.0 * u.V, v_max=2.0 * u.V, maximum_acceleration=acceleration, scale=scale) -input_channel = InputChannel('Dev4/ai0', -1.0 * u.V, 1.0 * u.V) +y_axis = Axis( + channel="Dev4/ao0", + v_min=-2.0 * u.V, + v_max=2.0 * u.V, + maximum_acceleration=acceleration, + scale=scale, +) +x_axis = Axis( + channel="Dev4/ao1", + v_min=-2.0 * u.V, + v_max=2.0 * u.V, + maximum_acceleration=acceleration, + scale=scale, +) +input_channel = InputChannel("Dev4/ai0", -1.0 * u.V, 1.0 * u.V) test_image = skimage.data.hubble_deep_field() * 256 -scanner = ScanningMicroscope(sample_rate=sample_rate, - input=input_channel, y_axis=y_axis, x_axis=x_axis, - test_pattern='image', reference_zoom=reference_zoom, - resolution=1024, test_image=test_image) +scanner = ScanningMicroscope( + sample_rate=sample_rate, + input=input_channel, + y_axis=y_axis, + x_axis=x_axis, + test_pattern="image", + reference_zoom=reference_zoom, + resolution=1024, + test_image=test_image, +) -if __name__ == '__main__': +if __name__ == "__main__": scanner.binning = 4 - plt.imshow(scanner.read(), cmap='gray') + plt.imshow(scanner.read(), cmap="gray") plt.colorbar() plt.show() else: - devices = {'microscope': scanner} + devices = {"microscope": scanner} diff --git a/examples/sample_microscope.py b/examples/sample_microscope.py index bdc2057..9259594 100644 --- a/examples/sample_microscope.py +++ b/examples/sample_microscope.py @@ -41,27 +41,42 @@ # Code img = set_pixel_size( - np.maximum(np.random.randint(-10000, 100, (img_size_y, img_size_x), dtype=np.int16), 0), - 60 * u.nm) + np.maximum( + np.random.randint(-10000, 100, (img_size_y, img_size_x), dtype=np.int16), 0 + ), + 60 * u.nm, +) src = StaticSource(img) -mic = Microscope(src, magnification=magnification, numerical_aperture=numerical_aperture, wavelength=wavelength) +mic = Microscope( + src, + magnification=magnification, + numerical_aperture=numerical_aperture, + wavelength=wavelength, +) # simulate shot noise in an 8-bit camera with auto-exposure: -cam = mic.get_camera(shot_noise=True, digital_max=255, data_shape=camera_resolution, pixel_size=pixel_size) -devices 
= {'camera': cam, 'stage': mic.xy_stage} +cam = mic.get_camera( + shot_noise=True, + digital_max=255, + data_shape=camera_resolution, + pixel_size=pixel_size, +) +devices = {"camera": cam, "stage": mic.xy_stage} -if __name__ == '__main__': +if __name__ == "__main__": import matplotlib.pyplot as plt plt.subplot(1, 2, 1) imshow(img) - plt.title('Original image') + plt.title("Original image") plt.subplot(1, 2, 2) - plt.title('Scanned image') + plt.title("Scanned image") ax = None for p in range(p_limit): mic.xy_stage.x = p * 1 * u.um mic.numerical_aperture = 1.0 * (p + 1) / p_limit # NA increases to 1.0 ax = grab_and_show(cam, ax) - plt.title(f"NA: {mic.numerical_aperture}, δ: {mic.abbe_limit.to_value(u.um):2.2} μm") + plt.title( + f"NA: {mic.numerical_aperture}, δ: {mic.abbe_limit.to_value(u.um):2.2} μm" + ) plt.pause(0.2) diff --git a/examples/slm_demo.py b/examples/slm_demo.py index bbb64c9..aba50bf 100644 --- a/examples/slm_demo.py +++ b/examples/slm_demo.py @@ -34,7 +34,9 @@ p4.phases = 1 p4.additive_blend = False -pf.phases = patterns.lens(100, f=1 * u.m, wavelength=0.8 * u.um, extent=(10 * u.mm, 10 * u.mm)) +pf.phases = patterns.lens( + 100, f=1 * u.m, wavelength=0.8 * u.um, extent=(10 * u.mm, 10 * u.mm) +) rng = np.random.default_rng() for n in range(200): random_data = rng.random([10, 10], np.float32) * 2.0 * np.pi diff --git a/examples/slm_disk.py b/examples/slm_disk.py index cbe79c0..87d8260 100644 --- a/examples/slm_disk.py +++ b/examples/slm_disk.py @@ -29,4 +29,4 @@ # read back the pixels and store in a file pixels = slm.pixels.read() -cv2.imwrite('slm_disk.png', pixels) +cv2.imwrite("slm_disk.png", pixels) diff --git a/examples/troubleshooter_demo.py b/examples/troubleshooter_demo.py index 706e1ee..abf08d6 100644 --- a/examples/troubleshooter_demo.py +++ b/examples/troubleshooter_demo.py @@ -31,12 +31,16 @@ # Simulate an SLM with incorrect phase response # Also simulate a shutter that can turn off the light # The SLM is conjugated to the back pupil plane -slm = SLM(shape=(100, 100), - phase_response=(np.arange(256) / 128 * np.pi) * 1.2) +slm = SLM(shape=(100, 100), phase_response=(np.arange(256) / 128 * np.pi) * 1.2) shutter = Shutter(slm.field) # Simulate a WFS microscope looking at the specimen -sim = Microscope(source=specimen, incident_field=shutter, aberrations=aberrations, wavelength=800 * u.nm) +sim = Microscope( + source=specimen, + incident_field=shutter, + aberrations=aberrations, + wavelength=800 * u.nm, +) # Simulate a camera device with gaussian noise and shot noise cam = sim.get_camera(analog_max=1e4, shot_noise=True, gaussian_noise_std=4.0) @@ -52,6 +56,7 @@ roi_background = SingleRoi(cam, radius=10) # Run WFS troubleshooter and output a report to the console -trouble = troubleshoot(algorithm=alg, background_feedback=roi_background, - frame_source=cam, shutter=shutter) +trouble = troubleshoot( + algorithm=alg, background_feedback=roi_background, frame_source=cam, shutter=shutter +) trouble.report() diff --git a/examples/wfs_demonstration_experimental.py b/examples/wfs_demonstration_experimental.py index ef845c3..e4998a2 100644 --- a/examples/wfs_demonstration_experimental.py +++ b/examples/wfs_demonstration_experimental.py @@ -19,12 +19,20 @@ # constructs the actual slm for wavefront shaping, and a monitor window to display the current phase pattern slm = SLM(monitor_id=2, duration=2) -monitor = slm.clone(monitor_id=0, pos=(0, 0), shape=(slm.shape[0] // 4, slm.shape[1] // 4)) +monitor = slm.clone( + monitor_id=0, pos=(0, 0), shape=(slm.shape[0] // 4, slm.shape[1] 
// 4) +) # we are using a setup with an SLM that produces 2pi phase shift # at a gray value of 142 slm.lookup_table = range(142) -alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=[800, 800], k_angles_min=-5, k_angles_max=5) +alg = FourierDualReference( + feedback=roi_detector, + slm=slm, + slm_shape=[800, 800], + k_angles_min=-5, + k_angles_max=5, +) result = alg.execute() print(result) diff --git a/openwfs/algorithms/basic_fourier.py b/openwfs/algorithms/basic_fourier.py index 20d5dd6..a8c2695 100644 --- a/openwfs/algorithms/basic_fourier.py +++ b/openwfs/algorithms/basic_fourier.py @@ -24,12 +24,16 @@ def build_square_k_space(k_min, k_max, k_step=1.0): """ # Generate kx and ky coordinates kx_angles = np.arange(k_min, k_max + 1, 1) - k_angles_min_even = (k_min if k_min % 2 == 0 else k_min + 1) # Must be even - ky_angles = np.arange(k_angles_min_even, k_max + 1, 2) # Steps of 2 + k_angles_min_even = k_min if k_min % 2 == 0 else k_min + 1 # Must be even + ky_angles = np.arange(k_angles_min_even, k_max + 1, 2) # Steps of 2 # Combine kx and ky coordinates into pairs - k_x = np.repeat(np.array(kx_angles)[np.newaxis, :], len(ky_angles), axis=0).flatten() - k_y = np.repeat(np.array(ky_angles)[:, np.newaxis], len(kx_angles), axis=1).flatten() + k_x = np.repeat( + np.array(kx_angles)[np.newaxis, :], len(ky_angles), axis=0 + ).flatten() + k_y = np.repeat( + np.array(ky_angles)[:, np.newaxis], len(kx_angles), axis=1 + ).flatten() k_space = np.vstack((k_x, k_y)) * k_step return k_space @@ -46,8 +50,16 @@ class FourierDualReference(FourierBase): "Wavefront shaping for forward scattering," Opt. Express 30, 37436-37445 (2022) """ - def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phase_steps=4, k_angles_min: int = -3, - k_angles_max: int = 3, analyzer: Optional[callable] = analyze_phase_stepping): + def __init__( + self, + feedback: Detector, + slm: PhaseSLM, + slm_shape=(500, 500), + phase_steps=4, + k_angles_min: int = -3, + k_angles_max: int = 3, + analyzer: Optional[callable] = analyze_phase_stepping, + ): """ Args: feedback (Detector): Source of feedback @@ -58,8 +70,15 @@ def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phas k_angles_min (int): The minimum k-angle. k_angles_max (int): The maximum k-angle. """ - super().__init__(feedback, slm, slm_shape, np.array((0, 0)), np.array((0, 0)), phase_steps=phase_steps, - analyzer=analyzer) + super().__init__( + feedback, + slm, + slm_shape, + np.array((0, 0)), + np.array((0, 0)), + phase_steps=phase_steps, + analyzer=analyzer, + ) self._k_angles_min = k_angles_min self._k_angles_max = k_angles_max @@ -86,7 +105,7 @@ def k_angles_min(self) -> int: @k_angles_min.setter def k_angles_min(self, value): """Sets the lower bound of the range of angles in x and y direction, triggers the building of the internal - k-space properties. + k-space properties. """ self._k_angles_min = value self._build_k_space() @@ -99,7 +118,7 @@ def k_angles_max(self) -> int: @k_angles_max.setter def k_angles_max(self, value): """Sets the higher bound of the range of angles in x and y direction, triggers the building of the internal - k-space properties.""" + k-space properties.""" self._k_angles_max = value self._build_k_space() @@ -112,8 +131,17 @@ class FourierDualReferenceCircle(FourierBase): [1]: Bahareh Mastiani, Gerwin Osnabrugge, and Ivo M. Vellekoop, "Wavefront shaping for forward scattering," Opt. 
Express 30, 37436-37445 (2022) """ - def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phase_steps=4, k_radius: float = 3.2, - k_step: float = 1.0, analyzer: Optional[callable] = analyze_phase_stepping): + + def __init__( + self, + feedback: Detector, + slm: PhaseSLM, + slm_shape=(500, 500), + phase_steps=4, + k_radius: float = 3.2, + k_step: float = 1.0, + analyzer: Optional[callable] = analyze_phase_stepping, + ): """ Args: feedback (Detector): Source of feedback @@ -128,8 +156,15 @@ def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phas # first build the k_space, then call super().__init__ with k_left=k_space, k_right=k_space. # TODO: Add custom grid spacing - super().__init__(feedback=feedback, slm=slm, slm_shape=slm_shape, k_left=np.array((0, 0)), - k_right=np.array((0, 0)), phase_steps=phase_steps, analyzer=analyzer) + super().__init__( + feedback=feedback, + slm=slm, + slm_shape=slm_shape, + k_left=np.array((0, 0)), + k_right=np.array((0, 0)), + phase_steps=phase_steps, + analyzer=analyzer, + ) self._k_radius = k_radius self.k_step = k_step @@ -150,7 +185,7 @@ def _build_k_space(self): k_space_square = build_square_k_space(-k_max, k_max, k_step=k_step) # Filter out k-space coordinates that are outside the circle of radius k_radius - k_mask = (np.linalg.norm(k_space_square, axis=0) <= k_radius) + k_mask = np.linalg.norm(k_space_square, axis=0) <= k_radius k_space = k_space_square[:, k_mask] self.k_left = k_space @@ -172,9 +207,9 @@ def plot_k_space(self): phi = np.linspace(0, 2 * np.pi, 200) x = self.k_radius * np.cos(phi) y = self.k_radius * np.sin(phi) - plt.plot(x, y, 'k') - plt.plot(self.k_left[0, :], self.k_left[1, :], 'ob', label='k_left') - plt.plot(self.k_right[0, :], self.k_right[1, :], '.r', label='k_right') - plt.xlabel('k_x') - plt.ylabel('k_y') - plt.gca().set_aspect('equal') + plt.plot(x, y, "k") + plt.plot(self.k_left[0, :], self.k_left[1, :], "ob", label="k_left") + plt.plot(self.k_right[0, :], self.k_right[1, :], ".r", label="k_right") + plt.xlabel("k_x") + plt.ylabel("k_y") + plt.gca().set_aspect("equal") diff --git a/openwfs/algorithms/custom_iter_dual_reference.py b/openwfs/algorithms/custom_iter_dual_reference.py index 5519595..ec26fdf 100644 --- a/openwfs/algorithms/custom_iter_dual_reference.py +++ b/openwfs/algorithms/custom_iter_dual_reference.py @@ -44,9 +44,20 @@ class CustomIterativeDualReference: https://opg.optica.org/oe/ abstract.cfm?uri=oe-27-8-1167 """ - def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape: tuple[int, int], phases: tuple[nd, nd], set1_mask: - nd, phase_steps: int = 4, iterations: int = 4, analyzer: Optional[callable] = analyze_phase_stepping, - do_try_full_patterns=False, do_progress_bar=True, progress_bar_kwargs={}): + def __init__( + self, + feedback: Detector, + slm: PhaseSLM, + slm_shape: tuple[int, int], + phases: tuple[nd, nd], + set1_mask: nd, + phase_steps: int = 4, + iterations: int = 4, + analyzer: Optional[callable] = analyze_phase_stepping, + do_try_full_patterns=False, + do_progress_bar=True, + progress_bar_kwargs={}, + ): """ Args: feedback (Detector): The feedback source, usually a detector that provides measurement data. 
@@ -83,7 +94,9 @@ def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape: tuple[int, int] self.do_progress_bar = do_progress_bar self.progress_bar_kwargs = progress_bar_kwargs - assert (phases[0].shape[0] == phases[1].shape[0]) and (phases[0].shape[1] == phases[1].shape[1]) + assert (phases[0].shape[0] == phases[1].shape[0]) and ( + phases[0].shape[1] == phases[1].shape[1] + ) self.phases = (phases[0].astype(np.float32), phases[1].astype(np.float32)) # Pre-compute set0 mask @@ -113,13 +126,20 @@ def execute(self) -> WFSResult: t_set = t_full t_set_all = [None] * self.iterations results_all = [None] * self.iterations # List to store all results - results_latest = [None, None] # The two latest results. Used for computing fidelity factors. - full_pattern_feedback = np.zeros(self.iterations) # List to store feedback from full patterns + results_latest = [ + None, + None, + ] # The two latest results. Used for computing fidelity factors. + full_pattern_feedback = np.zeros( + self.iterations + ) # List to store feedback from full patterns # Prepare progress bar if self.do_progress_bar: - num_measurements = np.ceil(self.iterations / 2) * self.modes[0].shape[2] \ - + np.floor(self.iterations / 2) * self.modes[1].shape[2] + num_measurements = ( + np.ceil(self.iterations / 2) * self.modes[0].shape[2] + + np.floor(self.iterations / 2) * self.modes[1].shape[2] + ) progress_bar = tqdm(total=num_measurements, **self.progress_bar_kwargs) else: progress_bar = None @@ -129,12 +149,20 @@ def execute(self) -> WFSResult: s = it % 2 # Set id: 0 or 1. Used to pick set A or B for phase stepping mod_mask = self.set_masks[s] t_prev = t_set - ref_phases = -np.angle(t_prev) # Shaped reference phase pattern from transmission matrix + ref_phases = -np.angle( + t_prev + ) # Shaped reference phase pattern from transmission matrix # Measure and compute - result = self._single_side_experiment(mod_phases=self.phases[s], ref_phases=ref_phases, - mod_mask=mod_mask, progress_bar=progress_bar) - t_set = self.compute_t_set(result, self.modes[s]) # Compute transmission matrix from measurements + result = self._single_side_experiment( + mod_phases=self.phases[s], + ref_phases=ref_phases, + mod_mask=mod_mask, + progress_bar=progress_bar, + ) + t_set = self.compute_t_set( + result, self.modes[s] + ) # Compute transmission matrix from measurements # Store results t_full = t_prev + t_set @@ -148,23 +176,34 @@ def execute(self) -> WFSResult: full_pattern_feedback[it] = self.feedback.read() # Compute average fidelity factors - fidelity_noise = weighted_average(results_latest[0].fidelity_noise, - results_latest[1].fidelity_noise, results_latest[0].n, - results_latest[1].n) - fidelity_amplitude = weighted_average(results_latest[0].fidelity_amplitude, - results_latest[1].fidelity_amplitude, results_latest[0].n, - results_latest[1].n) - fidelity_calibration = weighted_average(results_latest[0].fidelity_calibration, - results_latest[1].fidelity_calibration, results_latest[0].n, - results_latest[1].n) - - result = WFSResult(t=t_full, - t_f=None, - n=self.modes[0].shape[2]+self.modes[1].shape[2], - axis=2, - fidelity_noise=fidelity_noise, - fidelity_amplitude=fidelity_amplitude, - fidelity_calibration=fidelity_calibration) + fidelity_noise = weighted_average( + results_latest[0].fidelity_noise, + results_latest[1].fidelity_noise, + results_latest[0].n, + results_latest[1].n, + ) + fidelity_amplitude = weighted_average( + results_latest[0].fidelity_amplitude, + results_latest[1].fidelity_amplitude, + results_latest[0].n, + 
results_latest[1].n, + ) + fidelity_calibration = weighted_average( + results_latest[0].fidelity_calibration, + results_latest[1].fidelity_calibration, + results_latest[0].n, + results_latest[1].n, + ) + + result = WFSResult( + t=t_full, + t_f=None, + n=self.modes[0].shape[2] + self.modes[1].shape[2], + axis=2, + fidelity_noise=fidelity_noise, + fidelity_amplitude=fidelity_amplitude, + fidelity_calibration=fidelity_calibration, + ) # TODO: This is a dirty way to add attributes. Find better way. result.t_set_all = t_set_all @@ -172,8 +211,13 @@ def execute(self) -> WFSResult: result.full_pattern_feedback = full_pattern_feedback return result - def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, - progress_bar: Optional[tqdm] = None) -> WFSResult: + def _single_side_experiment( + self, + mod_phases: nd, + ref_phases: nd, + mod_mask: nd, + progress_bar: Optional[tqdm] = None, + ) -> WFSResult: """ Conducts experiments on one part of the SLM. @@ -195,13 +239,19 @@ def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, With float32: phase_pattern = phases_B + (phases_A + step) * mask ~2ms per phase pattern """ num_of_modes = mod_phases.shape[2] - measurements = np.zeros((num_of_modes, self.phase_steps, *self.feedback.data_shape)) - ref_phases_masked = (1.0 - mod_mask) * ref_phases # Pre-compute masked reference phase pattern + measurements = np.zeros( + (num_of_modes, self.phase_steps, *self.feedback.data_shape) + ) + ref_phases_masked = ( + 1.0 - mod_mask + ) * ref_phases # Pre-compute masked reference phase pattern for m in range(num_of_modes): for p in range(self.phase_steps): phase_step = p * 2 * np.pi / self.phase_steps - phase_pattern = ref_phases_masked + mod_mask * (mod_phases[:, :, m] + phase_step) + phase_pattern = ref_phases_masked + mod_mask * ( + mod_phases[:, :, m] + phase_step + ) self.slm.set_phases(phase_pattern) self.feedback.trigger(out=measurements[m, p, ...]) diff --git a/openwfs/algorithms/fourier.py b/openwfs/algorithms/fourier.py index 71f52c8..65b0228 100644 --- a/openwfs/algorithms/fourier.py +++ b/openwfs/algorithms/fourier.py @@ -10,21 +10,29 @@ class FourierBase: """Base class definition for the Fourier algorithms as described in [1]. - This algorithm optimises the wavefront in a Fourier-basis. The modes that are tested are provided into a 'k-space' - of which each 'k-vector' represents a certain angled wavefront that will be tested. (more detailed explanation is - found in _get_phase_pattern). - - As described in [1], these modes are measured by interfering a certain mode on one half of the SLM with a - 'reference beam'. This is done by not modulating the other half of the SLM. In order to find a full corrective - wavefront therefore, the experiment has to be repeated twice for each side of the SLM. Finally, the two wavefronts - are combined. - - [1]: Bahareh Mastiani, Gerwin Osnabrugge, and Ivo M. Vellekoop, - "Wavefront shaping for forward scattering," Opt. Express 30, 37436-37445 (2022) - """ - - def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape: tuple[int, int], k_left: np.ndarray, - k_right: np.ndarray, phase_steps: int = 4, analyzer: Optional[callable] = analyze_phase_stepping): + This algorithm optimises the wavefront in a Fourier-basis. The modes that are tested are provided into a 'k-space' + of which each 'k-vector' represents a certain angled wavefront that will be tested. (more detailed explanation is + found in _get_phase_pattern). 
+ + As described in [1], these modes are measured by interfering a certain mode on one half of the SLM with a + 'reference beam'. This is done by not modulating the other half of the SLM. In order to find a full corrective + wavefront therefore, the experiment has to be repeated twice for each side of the SLM. Finally, the two wavefronts + are combined. + + [1]: Bahareh Mastiani, Gerwin Osnabrugge, and Ivo M. Vellekoop, + "Wavefront shaping for forward scattering," Opt. Express 30, 37436-37445 (2022) + """ + + def __init__( + self, + feedback: Detector, + slm: PhaseSLM, + slm_shape: tuple[int, int], + k_left: np.ndarray, + k_right: np.ndarray, + phase_steps: int = 4, + analyzer: Optional[callable] = analyze_phase_stepping, + ): """ Args: @@ -88,7 +96,9 @@ def _single_side_experiment(self, k_set: np.ndarray, side: int) -> WFSResult: Returns: WFSResult: An object containing the computed SLM transmission matrix and related data. """ - measurements = np.zeros((k_set.shape[1], self.phase_steps, *self.feedback.data_shape)) + measurements = np.zeros( + (k_set.shape[1], self.phase_steps, *self.feedback.data_shape) + ) for i in range(k_set.shape[1]): for p in range(self.phase_steps): @@ -100,10 +110,9 @@ def _single_side_experiment(self, k_set: np.ndarray, side: int) -> WFSResult: self.feedback.wait() return self.analyzer(measurements, axis=1) - def _get_phase_pattern(self, - k: np.ndarray, - phase_offset: float, - side: int) -> np.ndarray: + def _get_phase_pattern( + self, k: np.ndarray, phase_offset: float, side: int + ) -> np.ndarray: """ Generates a phase pattern for the SLM based on the given spatial frequency, phase offset, and side. @@ -119,8 +128,12 @@ def _get_phase_pattern(self, # The natural step to take is the Abbe diffraction limit of the modulated part, which corresponds to a gradient # from -π to π over the modulated part. num_columns = self.slm_shape[1] // 2 - tilted_front = tilt([self.slm_shape[0], num_columns], k * (0.5 * np.pi), extent=(2.0, 1.0), - phase_offset=phase_offset) + tilted_front = tilt( + [self.slm_shape[0], num_columns], + k * (0.5 * np.pi), + extent=(2.0, 1.0), + phase_offset=phase_offset, + ) # Handle side-dependent pattern @@ -136,7 +149,9 @@ def _get_phase_pattern(self, return result - def compute_t(self, left: WFSResult, right: WFSResult, k_left, k_right) -> WFSResult: + def compute_t( + self, left: WFSResult, right: WFSResult, k_left, k_right + ) -> WFSResult: """ Computes the SLM transmission matrix by combining the Fourier transmission matrices from both sides of the SLM. 
@@ -152,8 +167,8 @@ def compute_t(self, left: WFSResult, right: WFSResult, k_left, k_right) -> WFSRe # TODO: determine noise # Initialize transmission matrices - t1 = np.zeros((*self.slm_shape, *self.feedback.data_shape), dtype='complex128') - t2 = np.zeros((*self.slm_shape, *self.feedback.data_shape), dtype='complex128') + t1 = np.zeros((*self.slm_shape, *self.feedback.data_shape), dtype="complex128") + t2 = np.zeros((*self.slm_shape, *self.feedback.data_shape), dtype="complex128") # Calculate phase difference between the two halves # We have two phase stepping measurements where both halves are flat (k=0) @@ -163,12 +178,18 @@ def compute_t(self, left: WFSResult, right: WFSResult, k_left, k_right) -> WFSRe # Find the index of the (0,0) mode in k_left and k_right index_0_left = np.argmin(k_left[0] ** 2 + k_left[1] ** 2) index_0_right = np.argmin(k_right[0] ** 2 + k_left[1] ** 2) - if not np.all(k_left[:, index_0_left] == 0.0) or not np.all(k_right[:, index_0_right] == 0.0): - raise Exception("k=(0,0) component missing from the measurement set, cannot determine relative phase.") + if not np.all(k_left[:, index_0_left] == 0.0) or not np.all( + k_right[:, index_0_right] == 0.0 + ): + raise Exception( + "k=(0,0) component missing from the measurement set, cannot determine relative phase." + ) # average the measurements for better accuracy # TODO: absolute values are not the same in simulation, 'A' scaling is off? - relative = 0.5 * (left.t[index_0_left, ...] + np.conjugate(right.t[index_0_right, ...])) + relative = 0.5 * ( + left.t[index_0_left, ...] + np.conjugate(right.t[index_0_right, ...]) + ) # Apply phase correction to the right side phase_correction = relative / np.abs(relative) @@ -181,12 +202,21 @@ def compute_t(self, left: WFSResult, right: WFSResult, k_left, k_right) -> WFSRe for n, t in enumerate(right.t): phi = self._get_phase_pattern(k_right[:, n], 0, 1) - t2 += np.tensordot(np.exp(-1j * phi), t * (normalisation * phase_correction), 0) + t2 += np.tensordot( + np.exp(-1j * phi), t * (normalisation * phase_correction), 0 + ) # Combine the left and right sides - t_full = np.concatenate([t1[:, :self.slm_shape[0] // 2, ...], t2[:, self.slm_shape[0] // 2:, ...]], axis=1) - t_f_full = np.concatenate([left.t_f, right.t_f], - axis=1) # also store raw data (not normalized or corrected yet!) + t_full = np.concatenate( + [ + t1[:, : self.slm_shape[0] // 2, ...], + t2[:, self.slm_shape[0] // 2 :, ...], + ], + axis=1, + ) + t_f_full = np.concatenate( + [left.t_f, right.t_f], axis=1 + ) # also store raw data (not normalized or corrected yet!) 
# return combined result, along with a course estimate of the snr and expected enhancement # TODO: not accurate yet @@ -194,10 +224,16 @@ def compute_t(self, left: WFSResult, right: WFSResult, k_left, k_right) -> WFSRe def weighted_average(x_left, x_right): return (left.n * x_left + right.n * x_right) / (left.n + right.n) - return WFSResult(t=t_full, - t_f=t_f_full, - n=left.n + right.n, - axis=2, - fidelity_noise=weighted_average(left.fidelity_noise, right.fidelity_noise), - fidelity_amplitude=weighted_average(left.fidelity_amplitude, right.fidelity_amplitude), - fidelity_calibration=weighted_average(left.fidelity_calibration, right.fidelity_calibration)) + return WFSResult( + t=t_full, + t_f=t_f_full, + n=left.n + right.n, + axis=2, + fidelity_noise=weighted_average(left.fidelity_noise, right.fidelity_noise), + fidelity_amplitude=weighted_average( + left.fidelity_amplitude, right.fidelity_amplitude + ), + fidelity_calibration=weighted_average( + left.fidelity_calibration, right.fidelity_calibration + ), + ) diff --git a/openwfs/algorithms/ssa.py b/openwfs/algorithms/ssa.py index c86c36a..7403c31 100644 --- a/openwfs/algorithms/ssa.py +++ b/openwfs/algorithms/ssa.py @@ -15,7 +15,14 @@ class StepwiseSequential: [2]: Ivo M. Vellekoop, "Feedback-based wavefront shaping," Opt. Express 23, 12189-12206 (2015) """ - def __init__(self, feedback: Detector, slm: PhaseSLM, phase_steps: int = 4, n_x: int = 4, n_y: int = None): + def __init__( + self, + feedback: Detector, + slm: PhaseSLM, + phase_steps: int = 4, + n_x: int = 4, + n_y: int = None, + ): """ This algorithm systematically modifies the phase pattern of each SLM element and measures the resulting feedback. @@ -39,8 +46,10 @@ def execute(self) -> WFSResult: Returns: WFSResult: An object containing the computed transmission matrix and statistics. """ - phase_pattern = np.zeros((self.n_y, self.n_x), 'float32') - measurements = np.zeros((self.n_y, self.n_x, self.phase_steps, *self.feedback.data_shape)) + phase_pattern = np.zeros((self.n_y, self.n_x), "float32") + measurements = np.zeros( + (self.n_y, self.n_x, self.phase_steps, *self.feedback.data_shape) + ) for y in range(self.n_y): for x in range(self.n_x): diff --git a/openwfs/algorithms/troubleshoot.py b/openwfs/algorithms/troubleshoot.py index 662f74d..4ae3cd4 100644 --- a/openwfs/algorithms/troubleshoot.py +++ b/openwfs/algorithms/troubleshoot.py @@ -47,8 +47,9 @@ def cnr(signal_with_noise: np.ndarray, noise: np.ndarray) -> np.float64: return signal_std(signal_with_noise, noise) / noise.std() -def contrast_enhancement(signal_with_noise: np.ndarray, reference_with_noise: np.ndarray, - noise: np.ndarray) -> float: +def contrast_enhancement( + signal_with_noise: np.ndarray, reference_with_noise: np.ndarray, noise: np.ndarray +) -> float: """ Compute noise corrected contrast enhancement. The noise is assumed to be uncorrelated with the signal, such that var(measured) = var(signal) + var(noise). @@ -124,7 +125,9 @@ def frame_correlation(a: np.ndarray, b: np.ndarray) -> float: return np.mean(a * b) / (np.mean(a) * np.mean(b)) - 1 -def pearson_correlation(a: np.ndarray, b: np.ndarray, noise_var: np.ndarray = 0.0) -> float: +def pearson_correlation( + a: np.ndarray, b: np.ndarray, noise_var: np.ndarray = 0.0 +) -> float: """ Compute Pearson correlation. @@ -160,9 +163,19 @@ class StabilityResult: framestack: 3D array containing all recorded frames. Is None unless saving frames was requested. 
""" - def __init__(self, pixel_shifts_first, correlations_first, correlations_disattenuated_first, contrast_ratios_first, - pixel_shifts_prev, correlations_prev, correlations_disattenuated_prev, contrast_ratios_prev, - abs_timestamps, framestack): + def __init__( + self, + pixel_shifts_first, + correlations_first, + correlations_disattenuated_first, + contrast_ratios_first, + pixel_shifts_prev, + correlations_prev, + correlations_disattenuated_prev, + contrast_ratios_prev, + abs_timestamps, + framestack, + ): # Comparison with first frame self.pixel_shifts_first = pixel_shifts_first self.correlations_first = correlations_first @@ -186,47 +199,66 @@ def plot(self): """ # Comparisons with first frame plt.figure() - plt.plot(self.timestamps, self.pixel_shifts_first, '.-', label='image-shift (pix)') - plt.title('Stability - Image shift w.r.t. first frame') - plt.ylabel('Image shift (pix)') - plt.xlabel('time (s)') + plt.plot( + self.timestamps, self.pixel_shifts_first, ".-", label="image-shift (pix)" + ) + plt.title("Stability - Image shift w.r.t. first frame") + plt.ylabel("Image shift (pix)") + plt.xlabel("time (s)") plt.figure() - plt.plot(self.timestamps, self.correlations_first, '.-', label='correlation') - plt.plot(self.timestamps, self.correlations_disattenuated_first, '.-', label='correlation disattenuated') - plt.title('Stability - Correlation with first frame') - plt.xlabel('time (s)') + plt.plot(self.timestamps, self.correlations_first, ".-", label="correlation") + plt.plot( + self.timestamps, + self.correlations_disattenuated_first, + ".-", + label="correlation disattenuated", + ) + plt.title("Stability - Correlation with first frame") + plt.xlabel("time (s)") plt.legend() plt.figure() - plt.plot(self.timestamps, self.contrast_ratios_first, '.-', label='contrast ratio') - plt.title('Stability - Contrast ratio with first frame') - plt.xlabel('time (s)') + plt.plot( + self.timestamps, self.contrast_ratios_first, ".-", label="contrast ratio" + ) + plt.title("Stability - Contrast ratio with first frame") + plt.xlabel("time (s)") # Comparisons with previous frame plt.figure() - plt.plot(self.timestamps, self.pixel_shifts_prev, '.-', label='image-shift (pix)') - plt.title('Stability - Image shift w.r.t. previous frame') - plt.ylabel('Image shift (pix)') - plt.xlabel('time (s)') + plt.plot( + self.timestamps, self.pixel_shifts_prev, ".-", label="image-shift (pix)" + ) + plt.title("Stability - Image shift w.r.t. 
previous frame") + plt.ylabel("Image shift (pix)") + plt.xlabel("time (s)") plt.figure() - plt.plot(self.timestamps, self.correlations_prev, '.-', label='correlation') - plt.plot(self.timestamps, self.correlations_disattenuated_prev, '.-', label='correlation disattenuated') - plt.title('Stability - Correlation with previous frame') - plt.xlabel('time (s)') + plt.plot(self.timestamps, self.correlations_prev, ".-", label="correlation") + plt.plot( + self.timestamps, + self.correlations_disattenuated_prev, + ".-", + label="correlation disattenuated", + ) + plt.title("Stability - Correlation with previous frame") + plt.xlabel("time (s)") plt.legend() plt.figure() - plt.plot(self.timestamps, self.contrast_ratios_prev, '.-', label='contrast ratio') - plt.title('Stability - Contrast ratio with previous frame') - plt.xlabel('time (s)') + plt.plot( + self.timestamps, self.contrast_ratios_prev, ".-", label="contrast ratio" + ) + plt.title("Stability - Contrast ratio with previous frame") + plt.xlabel("time (s)") plt.show() -def measure_setup_stability(frame_source, sleep_time_s, num_of_frames, dark_frame, - do_save_frames=False) -> StabilityResult: +def measure_setup_stability( + frame_source, sleep_time_s, num_of_frames, dark_frame, do_save_frames=False +) -> StabilityResult: """Test the setup stability by repeatedly reading frames.""" first_frame = frame_source.read() prev_frame = first_frame @@ -260,14 +292,22 @@ def measure_setup_stability(frame_source, sleep_time_s, num_of_frames, dark_fram # Compare with first frame pixel_shifts_first[n, :] = find_pixel_shift(first_frame, new_frame) correlations_first[n] = pearson_correlation(first_frame, new_frame) - correlations_disattenuated_first[n] = pearson_correlation(first_frame, new_frame, noise_var=dark_var) - contrast_ratios_first[n] = contrast_enhancement(new_frame, first_frame, dark_frame) + correlations_disattenuated_first[n] = pearson_correlation( + first_frame, new_frame, noise_var=dark_var + ) + contrast_ratios_first[n] = contrast_enhancement( + new_frame, first_frame, dark_frame + ) # Compare with previous frame pixel_shifts_prev[n, :] = find_pixel_shift(prev_frame, new_frame) correlations_prev[n] = pearson_correlation(prev_frame, new_frame) - correlations_disattenuated_prev[n] = pearson_correlation(prev_frame, new_frame, noise_var=dark_var) - contrast_ratios_prev[n] = contrast_enhancement(new_frame, prev_frame, dark_frame) + correlations_disattenuated_prev[n] = pearson_correlation( + prev_frame, new_frame, noise_var=dark_var + ) + contrast_ratios_prev[n] = contrast_enhancement( + new_frame, prev_frame, dark_frame + ) abs_timestamps[n] = time.perf_counter() # Save frame if requested @@ -276,19 +316,23 @@ def measure_setup_stability(frame_source, sleep_time_s, num_of_frames, dark_fram prev_frame = new_frame - return StabilityResult(pixel_shifts_first=pixel_shifts_first, - correlations_first=correlations_first, - correlations_disattenuated_first=correlations_disattenuated_first, - contrast_ratios_first=contrast_ratios_first, - pixel_shifts_prev=pixel_shifts_prev, - correlations_prev=correlations_prev, - correlations_disattenuated_prev=correlations_disattenuated_prev, - contrast_ratios_prev=contrast_ratios_prev, - abs_timestamps=abs_timestamps, - framestack=framestack) - - -def measure_modulated_light_dual_phase_stepping(slm: PhaseSLM, feedback: Detector, phase_steps: int, num_blocks: int): + return StabilityResult( + pixel_shifts_first=pixel_shifts_first, + correlations_first=correlations_first, + 
correlations_disattenuated_first=correlations_disattenuated_first, + contrast_ratios_first=contrast_ratios_first, + pixel_shifts_prev=pixel_shifts_prev, + correlations_prev=correlations_prev, + correlations_disattenuated_prev=correlations_disattenuated_prev, + contrast_ratios_prev=contrast_ratios_prev, + abs_timestamps=abs_timestamps, + framestack=framestack, + ) + + +def measure_modulated_light_dual_phase_stepping( + slm: PhaseSLM, feedback: Detector, phase_steps: int, num_blocks: int +): """ Measure the ratio of modulated light with the dual phase stepping method. @@ -323,12 +367,16 @@ def measure_modulated_light_dual_phase_stepping(slm: PhaseSLM, feedback: Detecto measurements[p, q] = feedback.read() # 2D Fourier transform the modulation measurements - f = np.fft.fft2(measurements) / phase_steps ** 2 + f = np.fft.fft2(measurements) / phase_steps**2 # Compute fidelity factor due to modulated light eps = 1e-6 # Epsilon term to prevent division by zero - m1_m2_ratio = (np.abs(f[0, 1]) ** 2 + eps) / (np.abs(f[1, 0]) ** 2 + eps) # Ratio of modulated intensities - fidelity_modulated = (1 + m1_m2_ratio) / (1 + m1_m2_ratio + np.abs(f[0, 1]) ** 2 / np.abs(f[1, -1]) ** 2) + m1_m2_ratio = (np.abs(f[0, 1]) ** 2 + eps) / ( + np.abs(f[1, 0]) ** 2 + eps + ) # Ratio of modulated intensities + fidelity_modulated = (1 + m1_m2_ratio) / ( + 1 + m1_m2_ratio + np.abs(f[0, 1]) ** 2 / np.abs(f[1, -1]) ** 2 + ) return fidelity_modulated @@ -362,7 +410,9 @@ def measure_modulated_light(slm: PhaseSLM, feedback: Detector, phase_steps: int) f = np.fft.fft(measurements) # Compute ratio of modulated light over total - fidelity_modulated = 0.5 * (1.0 + np.sqrt(np.clip(1.0 - 4.0 * np.abs(f[1] / f[0]) ** 2, 0, None))) + fidelity_modulated = 0.5 * ( + 1.0 + np.sqrt(np.clip(1.0 - 4.0 * np.abs(f[1] / f[0]) ** 2, 0, None)) + ) return fidelity_modulated @@ -438,72 +488,88 @@ def report(self, do_plots=True): Args: do_plots (bool): Plot some results as graphs. 
""" - print(f'\n===========================') - print(f'{time.ctime(self.timestamp)}\n') - print(f'=== Feedback metrics ===') - print(f'number of modes (N): {self.wfs_result.n:.3f}') - print(f'fidelity_amplitude: {self.wfs_result.fidelity_amplitude.squeeze():.3f}') - print(f'fidelity_noise: {self.wfs_result.fidelity_noise.squeeze():.3f}') - print(f'fidelity_non_modulated: {self.fidelity_non_modulated:.3f}') - print(f'fidelity_phase_calibration: {self.wfs_result.fidelity_calibration.squeeze():.3f}') - print(f'fidelity_decorrelation: {self.fidelity_decorrelation:.3f}') - print(f'expected enhancement: {self.expected_enhancement:.3f}') - print(f'measured enhancement: {self.measured_enhancement:.3f}') - print(f'') - print(f'=== Frame metrics ===') - print(f'signal std, before: {self.frame_signal_std_before:.2f}') - print(f'signal std, after: {self.frame_signal_std_after:.2f}') - print(f'signal std, with shaped wavefront: {self.frame_signal_std_shaped_wf:.2f}') + print(f"\n===========================") + print(f"{time.ctime(self.timestamp)}\n") + print(f"=== Feedback metrics ===") + print(f"number of modes (N): {self.wfs_result.n:.3f}") + print(f"fidelity_amplitude: {self.wfs_result.fidelity_amplitude.squeeze():.3f}") + print(f"fidelity_noise: {self.wfs_result.fidelity_noise.squeeze():.3f}") + print(f"fidelity_non_modulated: {self.fidelity_non_modulated:.3f}") + print( + f"fidelity_phase_calibration: {self.wfs_result.fidelity_calibration.squeeze():.3f}" + ) + print(f"fidelity_decorrelation: {self.fidelity_decorrelation:.3f}") + print(f"expected enhancement: {self.expected_enhancement:.3f}") + print(f"measured enhancement: {self.measured_enhancement:.3f}") + print(f"") + print(f"=== Frame metrics ===") + print(f"signal std, before: {self.frame_signal_std_before:.2f}") + print(f"signal std, after: {self.frame_signal_std_after:.2f}") + print( + f"signal std, with shaped wavefront: {self.frame_signal_std_shaped_wf:.2f}" + ) if self.dark_frame is not None: - print(f'average offset (dark frame): {self.dark_frame.mean():.2f}') - print(f'median offset (dark frame): {np.median(self.dark_frame):.2f}') - print(f'noise std (dark frame): {np.std(self.dark_frame):.2f}') - print(f'frame repeatability: {self.frame_repeatability:.3f}') - print(f'contrast to noise ratio before: {self.frame_cnr_before:.3f}') - print(f'contrast to noise ratio after: {self.frame_cnr_after:.3f}') - print(f'contrast to noise ratio with shaped wavefront: {self.frame_cnr_shaped_wf:.3f}') - print(f'contrast enhancement: {self.frame_contrast_enhancement:.3f}') - print(f'photobleaching ratio: {self.frame_photobleaching_ratio:.3f}') + print(f"average offset (dark frame): {self.dark_frame.mean():.2f}") + print(f"median offset (dark frame): {np.median(self.dark_frame):.2f}") + print(f"noise std (dark frame): {np.std(self.dark_frame):.2f}") + print(f"frame repeatability: {self.frame_repeatability:.3f}") + print(f"contrast to noise ratio before: {self.frame_cnr_before:.3f}") + print(f"contrast to noise ratio after: {self.frame_cnr_after:.3f}") + print( + f"contrast to noise ratio with shaped wavefront: {self.frame_cnr_shaped_wf:.3f}" + ) + print(f"contrast enhancement: {self.frame_contrast_enhancement:.3f}") + print(f"photobleaching ratio: {self.frame_photobleaching_ratio:.3f}") if do_plots and self.stability is not None: self.stability.plot() - if (do_plots and self.dark_frame is not None and self.after_frame is not None and - self.shaped_wf_frame is not None): + if ( + do_plots + and self.dark_frame is not None + and self.after_frame is not 
None + and self.shaped_wf_frame is not None + ): max_value = max(self.after_frame.max(), self.shaped_wf_frame.max()) # Plot dark frame plt.figure() plt.imshow(self.dark_frame, vmin=0, vmax=max_value) - plt.title('Dark frame') + plt.title("Dark frame") plt.colorbar() - plt.xlabel('x (pix)') - plt.ylabel('y (pix)') + plt.xlabel("x (pix)") + plt.ylabel("y (pix)") plt.figure() # Plot after frame with flat wf plt.imshow(self.after_frame, vmin=0, vmax=max_value) - plt.title('Frame with flat wavefront') + plt.title("Frame with flat wavefront") plt.colorbar() - plt.xlabel('x (pix)') - plt.ylabel('y (pix)') + plt.xlabel("x (pix)") + plt.ylabel("y (pix)") # Plot shaped wf frame plt.figure() plt.imshow(self.shaped_wf_frame, vmin=0, vmax=max_value) - plt.title('Frame with shaped wavefront') + plt.title("Frame with shaped wavefront") plt.colorbar() - plt.xlabel('x (pix)') - plt.ylabel('y (pix)') + plt.xlabel("x (pix)") + plt.ylabel("y (pix)") plt.show() -def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detector, shutter, - do_frame_capture=True, do_long_stability_test=False, - stability_sleep_time_s=0.5, - stability_num_of_frames=500, - stability_do_save_frames=False, - measure_non_modulated_phase_steps=16) -> WFSTroubleshootResult: +def troubleshoot( + algorithm, + background_feedback: Detector, + frame_source: Detector, + shutter, + do_frame_capture=True, + do_long_stability_test=False, + stability_sleep_time_s=0.5, + stability_num_of_frames=500, + stability_do_save_frames=False, + measure_non_modulated_phase_steps=16, +) -> WFSTroubleshootResult: """ Run a series of basic checks to find common sources of error in a WFS experiment. Quantifies several types of fidelity reduction. @@ -532,7 +598,7 @@ def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detecto trouble = WFSTroubleshootResult() if do_frame_capture: - logging.info('Capturing frames before WFS...') + logging.info("Capturing frames before WFS...") # Capture frames before WFS algorithm.slm.set_phases(0.0) # Flat wavefront @@ -543,12 +609,16 @@ def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detecto before_frame_2 = frame_source.read() # Frame metrics - trouble.frame_signal_std_before = signal_std(trouble.before_frame, trouble.dark_frame) + trouble.frame_signal_std_before = signal_std( + trouble.before_frame, trouble.dark_frame + ) trouble.frame_cnr_before = cnr(trouble.before_frame, trouble.dark_frame) - trouble.frame_repeatability = pearson_correlation(trouble.before_frame, before_frame_2) + trouble.frame_repeatability = pearson_correlation( + trouble.before_frame, before_frame_2 + ) if do_long_stability_test and do_frame_capture: - logging.info('Run long stability test...') + logging.info("Run long stability test...") # Test setup stability trouble.stability = measure_setup_stability( @@ -556,12 +626,13 @@ def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detecto sleep_time_s=stability_sleep_time_s, num_of_frames=stability_num_of_frames, dark_frame=trouble.dark_frame, - do_save_frames=stability_do_save_frames) + do_save_frames=stability_do_save_frames, + ) trouble.feedback_before = algorithm.feedback.read() # WFS experiment - logging.info('Run WFS algorithm...') + logging.info("Run WFS algorithm...") trouble.wfs_result = algorithm.execute() # Execute WFS algorithm # Flat wavefront @@ -570,39 +641,61 @@ def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detecto trouble.feedback_after = algorithm.feedback.read() if 
do_frame_capture: - logging.info('Capturing frames after WFS...') + logging.info("Capturing frames after WFS...") trouble.after_frame = frame_source.read() # After frame (flat wf) # Shaped wavefront algorithm.slm.set_phases(-np.angle(trouble.wfs_result.t)) trouble.feedback_shaped_wf = algorithm.feedback.read() - trouble.measured_enhancement = trouble.feedback_shaped_wf / trouble.average_background + trouble.measured_enhancement = ( + trouble.feedback_shaped_wf / trouble.average_background + ) if do_frame_capture: trouble.shaped_wf_frame = frame_source.read() # Shaped wavefront frame # Frame metrics - logging.info('Compute frame metrics...') - trouble.frame_signal_std_after = signal_std(trouble.after_frame, trouble.dark_frame) - trouble.frame_signal_std_shaped_wf = signal_std(trouble.shaped_wf_frame, trouble.dark_frame) - trouble.frame_cnr_after = cnr(trouble.after_frame, trouble.dark_frame) # Frame CNR after - trouble.frame_cnr_shaped_wf = cnr(trouble.shaped_wf_frame, trouble.dark_frame) # Frame CNR shaped wf - trouble.frame_contrast_enhancement = \ - contrast_enhancement(trouble.shaped_wf_frame, trouble.after_frame, trouble.dark_frame) - trouble.frame_photobleaching_ratio = \ - contrast_enhancement(trouble.after_frame, trouble.before_frame, trouble.dark_frame) - trouble.fidelity_decorrelation = \ - pearson_correlation(trouble.before_frame, trouble.after_frame, noise_var=trouble.dark_frame.var()) - - trouble.fidelity_non_modulated = \ - measure_modulated_light(slm=algorithm.slm, feedback=algorithm.feedback, - phase_steps=measure_non_modulated_phase_steps) + logging.info("Compute frame metrics...") + trouble.frame_signal_std_after = signal_std( + trouble.after_frame, trouble.dark_frame + ) + trouble.frame_signal_std_shaped_wf = signal_std( + trouble.shaped_wf_frame, trouble.dark_frame + ) + trouble.frame_cnr_after = cnr( + trouble.after_frame, trouble.dark_frame + ) # Frame CNR after + trouble.frame_cnr_shaped_wf = cnr( + trouble.shaped_wf_frame, trouble.dark_frame + ) # Frame CNR shaped wf + trouble.frame_contrast_enhancement = contrast_enhancement( + trouble.shaped_wf_frame, trouble.after_frame, trouble.dark_frame + ) + trouble.frame_photobleaching_ratio = contrast_enhancement( + trouble.after_frame, trouble.before_frame, trouble.dark_frame + ) + trouble.fidelity_decorrelation = pearson_correlation( + trouble.before_frame, + trouble.after_frame, + noise_var=trouble.dark_frame.var(), + ) + + trouble.fidelity_non_modulated = measure_modulated_light( + slm=algorithm.slm, + feedback=algorithm.feedback, + phase_steps=measure_non_modulated_phase_steps, + ) trouble.expected_enhancement = np.squeeze( - trouble.wfs_result.n * trouble.wfs_result.fidelity_amplitude * trouble.wfs_result.fidelity_noise - * trouble.fidelity_non_modulated * trouble.wfs_result.fidelity_calibration * trouble.fidelity_decorrelation) + trouble.wfs_result.n + * trouble.wfs_result.fidelity_amplitude + * trouble.wfs_result.fidelity_noise + * trouble.fidelity_non_modulated + * trouble.wfs_result.fidelity_calibration + * trouble.fidelity_decorrelation + ) # Analyze the WFS result - logging.info('Analyze WFS result...') + logging.info("Analyze WFS result...") return trouble diff --git a/openwfs/algorithms/utilities.py b/openwfs/algorithms/utilities.py index 2b9c9e3..ec2d040 100644 --- a/openwfs/algorithms/utilities.py +++ b/openwfs/algorithms/utilities.py @@ -27,15 +27,17 @@ class WFSResult: This is the offset that is caused by a bias in the detector signal, stray light, etc. Default value: 0.0. 
""" - def __init__(self, - t: np.ndarray, - t_f: np.ndarray, - axis: int, - fidelity_noise: ArrayLike, - fidelity_amplitude: ArrayLike, - fidelity_calibration: ArrayLike, - n: Optional[int] = None, - intensity_offset: Optional[ArrayLike] = 0.0): + def __init__( + self, + t: np.ndarray, + t_f: np.ndarray, + axis: int, + fidelity_noise: ArrayLike, + fidelity_amplitude: ArrayLike, + fidelity_calibration: ArrayLike, + n: Optional[int] = None, + intensity_offset: Optional[ArrayLike] = 0.0, + ): """ Args: t(ndarray): measured transmission matrix. @@ -65,20 +67,39 @@ def __init__(self, self.fidelity_amplitude = np.atleast_1d(fidelity_amplitude) self.fidelity_calibration = np.atleast_1d(fidelity_calibration) self.estimated_enhancement = np.atleast_1d( - 1.0 + (self.n - 1) * self.fidelity_amplitude * self.fidelity_noise * self.fidelity_calibration) - self.intensity_offset = intensity_offset * np.ones(self.fidelity_calibration.shape) if np.isscalar( - intensity_offset) \ + 1.0 + + (self.n - 1) + * self.fidelity_amplitude + * self.fidelity_noise + * self.fidelity_calibration + ) + self.intensity_offset = ( + intensity_offset * np.ones(self.fidelity_calibration.shape) + if np.isscalar(intensity_offset) else intensity_offset - after = np.sum(np.abs(t), tuple( - range(self.axis))) ** 2 * self.fidelity_noise * self.fidelity_calibration + intensity_offset + ) + after = ( + np.sum(np.abs(t), tuple(range(self.axis))) ** 2 + * self.fidelity_noise + * self.fidelity_calibration + + intensity_offset + ) self.estimated_optimized_intensity = np.atleast_1d(after) def __str__(self) -> str: - noise_warning = "OK" if self.fidelity_noise > 0.5 else "WARNING low signal quality." - amplitude_warning = "OK" if self.fidelity_amplitude > 0.5 else "WARNING uneven contribution of optical modes." - calibration_fidelity_warning = "OK" if self.fidelity_calibration > 0.5 else ( - "WARNING non-linear phase response, check " - "lookup table.") + noise_warning = ( + "OK" if self.fidelity_noise > 0.5 else "WARNING low signal quality." + ) + amplitude_warning = ( + "OK" + if self.fidelity_amplitude > 0.5 + else "WARNING uneven contribution of optical modes." 
+ ) + calibration_fidelity_warning = ( + "OK" + if self.fidelity_calibration > 0.5 + else ("WARNING non-linear phase response, check " "lookup table.") + ) return f""" Wavefront shaping results: fidelity_noise: {self.fidelity_noise} {noise_warning} @@ -88,7 +109,7 @@ def __str__(self) -> str: estimated_optimized_intensity: {self.estimated_optimized_intensity} """ - def select_target(self, b) -> 'WFSResult': + def select_target(self, b) -> "WFSResult": """ Returns the wavefront shaping results for a single target @@ -98,18 +119,21 @@ def select_target(self, b) -> 'WFSResult': Returns: WFSResults data for the specified target """ - return WFSResult(t=self.t.reshape((*self.t.shape[0:2], -1))[:, :, b], - t_f=self.t_f.reshape((*self.t_f.shape[0:2], -1))[:, :, b], - axis=self.axis, - intensity_offset=self.intensity_offset[:][b], - fidelity_noise=self.fidelity_noise[:][b], - fidelity_amplitude=self.fidelity_amplitude[:][b], - fidelity_calibration=self.fidelity_calibration[:][b], - n=self.n, - ) - - -def analyze_phase_stepping(measurements: np.ndarray, axis: int, A: Optional[float] = None): + return WFSResult( + t=self.t.reshape((*self.t.shape[0:2], -1))[:, :, b], + t_f=self.t_f.reshape((*self.t_f.shape[0:2], -1))[:, :, b], + axis=self.axis, + intensity_offset=self.intensity_offset[:][b], + fidelity_noise=self.fidelity_noise[:][b], + fidelity_amplitude=self.fidelity_amplitude[:][b], + fidelity_calibration=self.fidelity_calibration[:][b], + n=self.n, + ) + + +def analyze_phase_stepping( + measurements: np.ndarray, axis: int, A: Optional[float] = None +): """Analyzes the result of phase stepping measurements, returning matrix `t` and noise statistics This function assumes that all measurements were made using the same reference field `A` @@ -166,7 +190,9 @@ def analyze_phase_stepping(measurements: np.ndarray, axis: int, A: Optional[floa # compute the effect of amplitude variations. 
# for perfectly developed speckle, and homogeneous illumination, this factor will be pi/4 - amplitude_factor = np.mean(np.abs(t), segments) ** 2 / np.mean(np.abs(t) ** 2, segments) + amplitude_factor = np.mean(np.abs(t), segments) ** 2 / np.mean( + np.abs(t) ** 2, segments + ) # estimate the calibration error # we first construct a matrix that can be used to fit @@ -190,15 +216,26 @@ def analyze_phase_stepping(measurements: np.ndarray, axis: int, A: Optional[floa total_energy = np.sum(np.abs(t_f) ** 2) if phase_steps > 3: - noise_energy = (total_energy - signal_energy - offset_energy) / (phase_steps - 3) - noise_factor = np.abs(np.maximum(signal_energy - noise_energy, 0.0) / signal_energy) + noise_energy = (total_energy - signal_energy - offset_energy) / ( + phase_steps - 3 + ) + noise_factor = np.abs( + np.maximum(signal_energy - noise_energy, 0.0) / signal_energy + ) else: noise_factor = 1.0 # cannot estimate reliably calibration_fidelity = np.abs(c[1]) ** 2 / np.sum(np.abs(c[1:]) ** 2) - return WFSResult(t, t_f=t_f, axis=axis, fidelity_amplitude=amplitude_factor, fidelity_noise=noise_factor, - fidelity_calibration=calibration_fidelity, n=n) + return WFSResult( + t, + t_f=t_f, + axis=axis, + fidelity_amplitude=amplitude_factor, + fidelity_noise=noise_factor, + fidelity_calibration=calibration_fidelity, + n=n, + ) class WFSController: @@ -264,7 +301,9 @@ def wavefront(self, value): self._amplitude_factor = result.fidelity_amplitude self._estimated_enhancement = result.estimated_enhancement self._calibration_fidelity = result.fidelity_calibration - self._estimated_optimized_intensity = result.estimated_optimized_intensity + self._estimated_optimized_intensity = ( + result.estimated_optimized_intensity + ) self._snr = 1.0 / (1.0 / result.fidelity_noise - 1.0) self._result = result self.algorithm.slm.set_phases(self._optimized_wavefront) @@ -323,12 +362,12 @@ def snr(self) -> float: @property def recompute_wavefront(self) -> bool: - """Returns: bool that indicates whether the wavefront needs to be recomputed. """ + """Returns: bool that indicates whether the wavefront needs to be recomputed.""" return self._recompute_wavefront @recompute_wavefront.setter def recompute_wavefront(self, value): - """Sets the bool that indicates whether the wavefront needs to be recomputed. """ + """Sets the bool that indicates whether the wavefront needs to be recomputed.""" self._recompute_wavefront = value @property @@ -355,6 +394,8 @@ def test_wavefront(self, value): feedback_flat = self.algorithm.feedback.read().copy() self.wavefront = WFSController.State.SHAPED_WAVEFRONT feedback_shaped = self.algorithm.feedback.read().copy() - self._feedback_enhancement = float(feedback_shaped.sum() / feedback_flat.sum()) + self._feedback_enhancement = float( + feedback_shaped.sum() / feedback_flat.sum() + ) self._test_wavefront = value diff --git a/openwfs/core.py b/openwfs/core.py index 8657b37..4b2756f 100644 --- a/openwfs/core.py +++ b/openwfs/core.py @@ -17,12 +17,21 @@ class Device(ABC): """Base class for detectors and actuators - See :ref:`key_concepts` for more information. + See :ref:`key_concepts` for more information. 
""" - __slots__ = ('_end_time_ns', '_timeout_margin', '_locking_thread', '_error', - '__weakref__', '_latency', '_duration', '_multi_threaded') - _workers = ThreadPoolExecutor(thread_name_prefix='Device._workers') + + __slots__ = ( + "_end_time_ns", + "_timeout_margin", + "_locking_thread", + "_error", + "__weakref__", + "_latency", + "_duration", + "_multi_threaded", + ) + _workers = ThreadPoolExecutor(thread_name_prefix="Device._workers") _moving = False _state_lock = threading.Lock() _devices: "Set[Device]" = WeakSet() @@ -62,15 +71,25 @@ def _start(self): else: logging.debug("switch to MOVING requested by %s.", self) - same_type = [device for device in Device._devices if device._is_actuator == self._is_actuator] - other_type = [device for device in Device._devices if device._is_actuator != self._is_actuator] + same_type = [ + device + for device in Device._devices + if device._is_actuator == self._is_actuator + ] + other_type = [ + device + for device in Device._devices + if device._is_actuator != self._is_actuator + ] # compute the minimum latency of same_type # for instance, when switching to 'measuring', this number tells us how long it takes before any of the # detectors actually starts a measurement. # If this is a positive number, we can make the switch to 'measuring' slightly _before_ # all actuators have stabilized. - latency = min([device.latency for device in same_type], default=0.0 * u.ns) # noqa - incorrect warning + latency = min( + [device.latency for device in same_type], default=0.0 * u.ns + ) # noqa - incorrect warning # wait until all devices of the other type have (almost) finished for device in other_type: @@ -85,7 +104,11 @@ def _start(self): # also store the time we expect the operation to finish # note: it may finish slightly earlier since (latency + duration) is a maximum value - self._end_time_ns = time.time_ns() + self.latency.to_value(u.ns) + self.duration.to_value(u.ns) + self._end_time_ns = ( + time.time_ns() + + self.latency.to_value(u.ns) + + self.duration.to_value(u.ns) + ) @property def latency(self) -> Quantity[u.ms]: @@ -173,13 +196,15 @@ def wait(self, up_to: Optional[Quantity[u.ms]] = None) -> None: while self.busy(): time.sleep(0.01) if time.time_ns() - start > timeout: - raise TimeoutError("Timeout in %s (tid %i)", self, threading.get_ident()) + raise TimeoutError( + "Timeout in %s (tid %i)", self, threading.get_ident() + ) else: time_to_wait = self._end_time_ns - time.time_ns() if up_to is not None: time_to_wait -= up_to.to_value(u.ns) if time_to_wait > 0: - time.sleep(time_to_wait / 1.0E9) + time.sleep(time_to_wait / 1.0e9) def busy(self) -> bool: """Returns true if the device is measuring or moving (see `wait()`). @@ -210,8 +235,8 @@ def timeout(self, value): class Actuator(Device, ABC): - """Base class for all actuators - """ + """Base class for all actuators""" + __slots__ = () @final @@ -224,10 +249,23 @@ class Detector(Device, ABC): See :numref:`Detectors` in the documentation for more information. 
""" - __slots__ = ('_measurements_pending', '_lock_condition', '_pixel_size', '_data_shape') - def __init__(self, *, data_shape: Optional[tuple[int, ...]], pixel_size: Optional[Quantity], - duration: Optional[Quantity[u.ms]], latency: Optional[Quantity[u.ms]], multi_threaded: bool = True): + __slots__ = ( + "_measurements_pending", + "_lock_condition", + "_pixel_size", + "_data_shape", + ) + + def __init__( + self, + *, + data_shape: Optional[tuple[int, ...]], + pixel_size: Optional[Quantity], + duration: Optional[Quantity[u.ms]], + latency: Optional[Quantity[u.ms]], + multi_threaded: bool = True + ): """ Constructor for the Detector class. @@ -362,11 +400,19 @@ def __do_fetch(self, out_, *args_, **kwargs_): """Helper function that awaits all futures in the keyword argument list, and then calls _fetch""" try: if len(args_) > 0 or len(kwargs_) > 0: - logging.debug("awaiting inputs for %s (tid: %i).", self, threading.get_ident()) - awaited_args = [(arg.result() if isinstance(arg, Future) else arg) for arg in args_] - awaited_kwargs = {key: (arg.result() if isinstance(arg, Future) else arg) for (key, arg) in - kwargs_.items()} - logging.debug("fetching data of %s ((tid: %i)).", self, threading.get_ident()) + logging.debug( + "awaiting inputs for %s (tid: %i).", self, threading.get_ident() + ) + awaited_args = [ + (arg.result() if isinstance(arg, Future) else arg) for arg in args_ + ] + awaited_kwargs = { + key: (arg.result() if isinstance(arg, Future) else arg) + for (key, arg) in kwargs_.items() + } + logging.debug( + "fetching data of %s ((tid: %i)).", self, threading.get_ident() + ) data = self._fetch(*awaited_args, **awaited_kwargs) data = set_pixel_size(data, self.pixel_size) assert data.shape == self.data_shape @@ -404,7 +450,7 @@ def __setattr__(self, key, value): """ # note: the check needs to be in this order, otherwise we cannot initialize set _multi_threaded - if not key.startswith('_') and self._multi_threaded: + if not key.startswith("_") and self._multi_threaded: with self._lock_condition: while self._measurements_pending > 0: self._lock_condition.wait() @@ -477,10 +523,16 @@ def coordinates(self, dimension: int) -> Quantity: Args: dimension: Dimension for which to return the coordinates. """ - unit = u.dimensionless_unscaled if self.pixel_size is None else self.pixel_size[dimension] + unit = ( + u.dimensionless_unscaled + if self.pixel_size is None + else self.pixel_size[dimension] + ) shape = np.ones_like(self.data_shape) shape[dimension] = self.data_shape[dimension] - return np.arange(0.5, 0.5 + self.data_shape[dimension], 1.0).reshape(shape) * unit + return ( + np.arange(0.5, 0.5 + self.data_shape[dimension], 1.0).reshape(shape) * unit + ) @final @property @@ -515,19 +567,30 @@ def __init__(self, *args, multi_threaded: bool): # when the settings of one of the source detectors is changed. # Therefore, we pass 'None' for all parameters, and override # data_shape, pixel_size, duration and latency in the properties. 
- super().__init__(data_shape=None, pixel_size=None, duration=None, latency=None, multi_threaded=multi_threaded) + super().__init__( + data_shape=None, + pixel_size=None, + duration=None, + latency=None, + multi_threaded=multi_threaded, + ) def trigger(self, *args, immediate=False, **kwargs): """Triggers all sources at the same time (regardless of latency), and schedules a call to `_fetch()`""" - future_data = [(source.trigger(immediate=immediate) if source is not None else None) for source in - self._sources] + future_data = [ + (source.trigger(immediate=immediate) if source is not None else None) + for source in self._sources + ] return super().trigger(*future_data, *args, **kwargs) @final @property def latency(self) -> Quantity[u.ms]: """Returns the shortest latency for all detectors.""" - return min((source.latency for source in self._sources if source is not None), default=0.0 * u.ms) + return min( + (source.latency for source in self._sources if source is not None), + default=0.0 * u.ms, + ) @final @property @@ -538,11 +601,16 @@ def duration(self) -> Quantity[u.ms]: Note that `latency` is allowed to vary over time for devices that can only be triggered periodically, so this `duration` may also vary over time. """ - times = [(source.duration, source.latency) for source in self._sources if source is not None] + times = [ + (source.duration, source.latency) + for source in self._sources + if source is not None + ] if len(times) == 0: return 0.0 * u.ms - return (max([duration + latency for (duration, latency) in times]) - - min([latency for (duration, latency) in times])) + return max([duration + latency for (duration, latency) in times]) - min( + [latency for (duration, latency) in times] + ) @property def data_shape(self): @@ -556,8 +624,8 @@ def pixel_size(self) -> Optional[Quantity]: class PhaseSLM(ABC): - """Base class for phase-only SLMs - """ + """Base class for phase-only SLMs""" + __slots__ = () @abstractmethod diff --git a/openwfs/devices/camera.py b/openwfs/devices/camera.py index c28a7d9..3e35550 100644 --- a/openwfs/devices/camera.py +++ b/openwfs/devices/camera.py @@ -13,7 +13,8 @@ ```pip install harvesters``` Alternatively, specify the genicam dependency when installing openwfs: ```pip install openwfs[genicam]``` - """) + """ + ) from ..core import Detector @@ -42,29 +43,37 @@ class Camera(Detector): >>> camera = Camera(cti_file=R"C:\\Program Files\\Basler\\pylon 7\\Runtime\\x64\\ProducerU3V.cti") >>> camera.exposure_time = 10 * u.ms >>> frame = camera.read() + """ + + def __init__( + self, + cti_file: str, + serial_number: Optional[str] = None, + multi_threaded=True, + **kwargs, + ): """ - - def __init__(self, cti_file: str, serial_number: Optional[str] = None, multi_threaded=True, **kwargs): - """ - Initialize the Camera object. - - Args: - cti_file: The path to the GenTL producer file. - This path depends on where the driver for the camera is installed. - For Basler cameras, this is typically located in - R"C:\\Program Files\\Basler\\pylon 7\\Runtime\\x64\\ProducerU3V.cti". - - serial_number: The serial number of the camera. - When omitted, the first camera found is selected. - **kwargs: Additional keyword arguments. - These arguments are transferred to the node map of the camera. + Initialize the Camera object. + + Args: + cti_file: The path to the GenTL producer file. + This path depends on where the driver for the camera is installed. + For Basler cameras, this is typically located in + R"C:\\Program Files\\Basler\\pylon 7\\Runtime\\x64\\ProducerU3V.cti". 
+ + serial_number: The serial number of the camera. + When omitted, the first camera found is selected. + **kwargs: Additional keyword arguments. + These arguments are transferred to the node map of the camera. """ self._harvester = Harvester() self._harvester.add_file(cti_file, check_validity=True) self._harvester.update() # open the camera, use the serial_number to select the camera if it is specified. - search_key = {'serial_number': serial_number} if serial_number is not None else None + search_key = ( + {"serial_number": serial_number} if serial_number is not None else None + ) self._camera = self._harvester.create(search_key=search_key) nodes = self._camera.remote_device.node_map @@ -72,10 +81,10 @@ def __init__(self, cti_file: str, serial_number: Optional[str] = None, multi_thr # set triggering to 'Software', so that we can trigger the camera by calling `trigger`. # turn off auto exposure so that `duration` accurately reflects the required measurement time. - nodes.TriggerMode.value = 'On' - nodes.TriggerSource.value = 'Software' - nodes.ExposureMode.value = 'Timed' - nodes.ExposureAuto.value = 'Off' + nodes.TriggerMode.value = "On" + nodes.TriggerSource.value = "Software" + nodes.ExposureMode.value = "Timed" + nodes.ExposureAuto.value = "Off" nodes.BinningHorizontal.value = 1 nodes.BinningVertical.value = 1 nodes.OffsetX.value = 0 @@ -104,22 +113,30 @@ def __init__(self, cti_file: str, serial_number: Optional[str] = None, multi_thr try: setattr(nodes, key, value) except AttributeError: - print(f'Warning: could not set camera property {key} to {value}') + print(f"Warning: could not set camera property {key} to {value}") try: - pixel_size = [nodes.SensorPixelHeight.value, nodes.SensorPixelWidth.value] * u.um + pixel_size = [ + nodes.SensorPixelHeight.value, + nodes.SensorPixelWidth.value, + ] * u.um except AttributeError: # the SensorPixelWidth feature is optional pixel_size = None - super().__init__(multi_threaded=multi_threaded, data_shape=None, pixel_size=pixel_size, duration=None, - latency=0.0 * u.ms) + super().__init__( + multi_threaded=multi_threaded, + data_shape=None, + pixel_size=pixel_size, + duration=None, + latency=0.0 * u.ms, + ) self._camera.start() def __del__(self): - if hasattr(self, '_camera'): + if hasattr(self, "_camera"): self._camera.stop() self._camera.destroy() - if hasattr(self, '_harvester'): + if hasattr(self, "_harvester"): self._harvester.reset() def _do_trigger(self): @@ -138,7 +155,7 @@ def _fetch(self, *args, **kwargs) -> np.ndarray: buffer = self._camera.fetch() frame = buffer.payload.components[0].data.reshape(self.data_shape) if frame.size == 0: - raise Exception('Camera returned an empty frame') + raise Exception("Camera returned an empty frame") data = frame.copy() buffer.queue() # give back buffer to the camera driver return data diff --git a/openwfs/devices/galvo_scanner.py b/openwfs/devices/galvo_scanner.py index 8fd2e7e..0b5db7c 100644 --- a/openwfs/devices/galvo_scanner.py +++ b/openwfs/devices/galvo_scanner.py @@ -19,7 +19,8 @@ ```pip install nidaqmx``` Alternatively, specify the genicam dependency when installing openwfs: ```pip install openwfs[nidaq]``` - """) + """ + ) from ..core import Detector from ..utilities import unitless @@ -36,6 +37,7 @@ class InputChannel: terminal_configuration: The terminal configuration of the channel, defaults to `TerminalConfiguration.DEFAULT` """ + channel: str v_min: Quantity[u.V] v_max: Quantity[u.V] @@ -65,11 +67,12 @@ class Axis: terminal_configuration: The terminal configuration of the channel, 
defaults to `TerminalConfiguration.DEFAULT` """ + channel: str v_min: Quantity[u.V] v_max: Quantity[u.V] scale: Quantity[u.um / u.V] - maximum_acceleration: Quantity[u.V / u.s ** 2] + maximum_acceleration: Quantity[u.V / u.s**2] terminal_configuration: TerminalConfiguration = TerminalConfiguration.DEFAULT def to_volt(self, pos: Union[np.ndarray, float]) -> Quantity[u.V]: @@ -103,16 +106,22 @@ def maximum_scan_speed(self, linear_range: float): Quantity[u.V / u.s]: maximum scan speed """ # x = 0.5 · a · t² = 0.5 (v_max - v_min) · (1 - linear_range) - t_accel = np.sqrt((self.v_max - self.v_min) * (1 - linear_range) / self.maximum_acceleration) + t_accel = np.sqrt( + (self.v_max - self.v_min) * (1 - linear_range) / self.maximum_acceleration + ) hardware_limit = t_accel * self.maximum_acceleration # t_linear = linear_range · (v_max - v_min) / maximum_speed # t_accel = maximum_speed / maximum_acceleration # 0.5·t_linear == t_accel => 0.5·linear_range · (v_max-v_min) · maximum_acceleration = maximum_speed² - practical_limit = np.sqrt(0.5 * linear_range * (self.v_max - self.v_min) * self.maximum_acceleration) + practical_limit = np.sqrt( + 0.5 * linear_range * (self.v_max - self.v_min) * self.maximum_acceleration + ) return np.minimum(hardware_limit, practical_limit) - def step(self, start: float, stop: float, sample_rate: Quantity[u.Hz]) -> Quantity[u.V]: + def step( + self, start: float, stop: float, sample_rate: Quantity[u.Hz] + ) -> Quantity[u.V]: """ Generate a voltage sequence to move from `start` to `stop` in the fastest way possible. @@ -138,15 +147,23 @@ def step(self, start: float, stop: float, sample_rate: Quantity[u.Hz]) -> Quanti # `t` is measured in samples # `a` is measured in volt/sample² - a = self.maximum_acceleration / sample_rate ** 2 * np.sign(v_end - v_start) + a = self.maximum_acceleration / sample_rate**2 * np.sign(v_end - v_start) t_total = unitless(2.0 * np.sqrt((v_end - v_start) / a)) - t = np.arange(np.ceil(t_total + 1E-6)) # add a small number to deal with case t=0 (start=end) - v_accel = v_start + 0.5 * a * t[:len(t) // 2] ** 2 # acceleration part - v_decel = v_end - 0.5 * a * (t_total - t[len(t) // 2:]) ** 2 # deceleration part + t = np.arange( + np.ceil(t_total + 1e-6) + ) # add a small number to deal with case t=0 (start=end) + v_accel = v_start + 0.5 * a * t[: len(t) // 2] ** 2 # acceleration part + v_decel = ( + v_end - 0.5 * a * (t_total - t[len(t) // 2 :]) ** 2 + ) # deceleration part v_decel[-1] = v_end # fix last point because t may be > t_total due to rounding - return np.clip(np.concatenate((v_accel, v_decel)), self.v_min, self.v_max) # noqa ignore incorrect type warning + return np.clip( + np.concatenate((v_accel, v_decel)), self.v_min, self.v_max + ) # noqa ignore incorrect type warning - def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quantity[u.Hz]): + def scan( + self, start: float, stop: float, sample_count: int, sample_rate: Quantity[u.Hz] + ): """ Generate a voltage sequence to scan with a constant velocity from start to stop, including acceleration and deceleration. @@ -172,7 +189,12 @@ def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quanti """ v_start = self.to_volt(start) if start == stop: # todo: tolerance? 
- return np.ones((sample_count,)) * v_start, start, start, slice(0, sample_count) + return ( + np.ones((sample_count,)) * v_start, + start, + start, + slice(0, sample_count), + ) v_end = self.to_volt(stop) scan_speed = (v_end - v_start) / sample_count # V per sample @@ -181,9 +203,13 @@ def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quanti # we start by constructing a sequence with a maximum acceleration. # This sequence may be up to 1 sample longer than needed to reach the scan speed. # This last sample is replaced by movement at a linear scan speed - a = self.maximum_acceleration / sample_rate ** 2 * np.sign(scan_speed) # V per sample² + a = ( + self.maximum_acceleration / sample_rate**2 * np.sign(scan_speed) + ) # V per sample² t_launch = np.arange(np.ceil(unitless(scan_speed / a))) # in samples - v_accel = 0.5 * a * t_launch ** 2 # last sample may have faster scan speed than needed + v_accel = ( + 0.5 * a * t_launch**2 + ) # last sample may have faster scan speed than needed if len(v_accel) > 1 and np.abs(v_accel[-1] - v_accel[-2]) > np.abs(scan_speed): v_accel[-1] = v_accel[-2] + scan_speed v_launch = v_start - v_accel[-1] - 0.5 * scan_speed # launch point @@ -200,8 +226,13 @@ def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quanti return v, launch, land, slice(len(v_accel), len(v_accel) + sample_count) @staticmethod - def compute_scale(*, optical_deflection: Quantity[u.deg / u.V], galvo_to_pupil_magnification: float, - objective_magnification: float, reference_tube_lens: Quantity[u.mm]) -> Quantity[u.um / u.V]: + def compute_scale( + *, + optical_deflection: Quantity[u.deg / u.V], + galvo_to_pupil_magnification: float, + objective_magnification: float, + reference_tube_lens: Quantity[u.mm], + ) -> Quantity[u.um / u.V]: """Computes the conversion factor between voltage and displacement in the object plane. Args: @@ -224,12 +255,18 @@ def compute_scale(*, optical_deflection: Quantity[u.deg / u.V], galvo_to_pupil_m """ f_objective = reference_tube_lens / objective_magnification angle_to_displacement = f_objective / u.rad - return ((optical_deflection / galvo_to_pupil_magnification) * angle_to_displacement).to(u.um / u.V) + return ( + (optical_deflection / galvo_to_pupil_magnification) * angle_to_displacement + ).to(u.um / u.V) @staticmethod - def compute_acceleration(*, optical_deflection: Quantity[u.deg / u.V], torque_constant: Quantity[u.N * u.m / u.A], - rotor_inertia: Quantity[u.kg * u.m ** 2], - maximum_current: Quantity[u.A]) -> Quantity[u.V / u.s ** 2]: + def compute_acceleration( + *, + optical_deflection: Quantity[u.deg / u.V], + torque_constant: Quantity[u.N * u.m / u.A], + rotor_inertia: Quantity[u.kg * u.m**2], + maximum_current: Quantity[u.A], + ) -> Quantity[u.V / u.s**2]: """Computes the angular acceleration of the focus of the galvo mirror. The result is returned in the unit V / second², @@ -247,16 +284,19 @@ def compute_acceleration(*, optical_deflection: Quantity[u.deg / u.V], torque_co maximum_current (Quantity[u.A]): The maximum current that can be applied to the galvo mirror. 
""" - angular_acceleration = (torque_constant * maximum_current / rotor_inertia).to(u.s ** -2) * u.rad - return (angular_acceleration / optical_deflection).to(u.V / u.s ** 2) + angular_acceleration = (torque_constant * maximum_current / rotor_inertia).to( + u.s**-2 + ) * u.rad + return (angular_acceleration / optical_deflection).to(u.V / u.s**2) class TestPatternType(Enum): """Type of test pattern to use for simulation.""" - NONE = 'none' - HORIZONTAL = 'horizontal' - VERTICAL = 'vertical' - IMAGE = 'image' + + NONE = "none" + HORIZONTAL = "horizontal" + VERTICAL = "vertical" + IMAGE = "image" class ScanningMicroscope(Detector): @@ -308,19 +348,22 @@ class ScanningMicroscope(Detector): parameter can be used. """ - def __init__(self, - input: InputChannel, - y_axis: Axis, - x_axis: Axis, - sample_rate: Quantity[u.MHz], - resolution: int, - reference_zoom: float, *, - delay: Quantity[u.us] = 0.0 * u.us, - bidirectional: bool = True, - multi_threaded: bool = True, - preprocessor: Optional[callable] = None, - test_pattern: Union[TestPatternType, str] = TestPatternType.NONE, - test_image=None): + def __init__( + self, + input: InputChannel, + y_axis: Axis, + x_axis: Axis, + sample_rate: Quantity[u.MHz], + resolution: int, + reference_zoom: float, + *, + delay: Quantity[u.us] = 0.0 * u.us, + bidirectional: bool = True, + multi_threaded: bool = True, + preprocessor: Optional[callable] = None, + test_pattern: Union[TestPatternType, str] = TestPatternType.NONE, + test_image=None, + ): """ Args: resolution: number of pixels (height and width) in the full field of view. @@ -353,8 +396,12 @@ def __init__(self, self._resolution = int(resolution) self._roi_top = 0 # in pixels self._roi_left = 0 # in pixels - self._center_x = 0.5 # in relative coordinates (relative to the full field of view) - self._center_y = 0.5 # in relative coordinates (relative to the full field of view) + self._center_x = ( + 0.5 # in relative coordinates (relative to the full field of view) + ) + self._center_y = ( + 0.5 # in relative coordinates (relative to the full field of view) + ) self._delay = delay.to(u.us) self._reference_zoom = float(reference_zoom) self._zoom = 1.0 @@ -365,9 +412,9 @@ def __init__(self, self._test_pattern = TestPatternType(test_pattern) self._test_image = None if test_image is not None: - self._test_image = np.array(test_image, dtype='uint16') + self._test_image = np.array(test_image, dtype="uint16") while self._test_image.ndim > 2: - self._test_image = np.mean(self._test_image, 2).astype('uint16') + self._test_image = np.mean(self._test_image, 2).astype("uint16") self._preprocessor = preprocessor @@ -379,9 +426,13 @@ def __init__(self, # the pixel size and duration are computed dynamically # data_shape just returns self._data shape, and latency = 0.0 ms - super().__init__(data_shape=(resolution, resolution), pixel_size=None, duration=None, - latency=0.0 * u.ms, - multi_threaded=multi_threaded) + super().__init__( + data_shape=(resolution, resolution), + pixel_size=None, + duration=None, + latency=0.0 * u.ms, + multi_threaded=multi_threaded, + ) self._update() def _update(self): @@ -411,7 +462,9 @@ def _update(self): # Compute the retrace pattern for the slow axis # The scan starts at half a pixel after roi_bottom and ends half a pixel before roi_top - v_yr = self._y_axis.step(roi_bottom - 0.5 * roi_scale, roi_top + 0.5 * roi_scale, self._sample_rate) + v_yr = self._y_axis.step( + roi_bottom - 0.5 * roi_scale, roi_top + 0.5 * roi_scale, self._sample_rate + ) # Compute the scan pattern for the fast axis 
# The naive speed is the scan speed assuming one pixel per sample @@ -419,22 +472,33 @@ def _update(self): # (at least, without spending more time on accelerating and decelerating than the scan itself) # The user can set the scan speed relative to the maximum speed. # If this set speed is lower than naive scan speed, multiple samples are taken per pixel. - naive_speed = (self._x_axis.v_max - self._x_axis.v_min) * roi_scale * self._sample_rate - max_speed = self._x_axis.maximum_scan_speed(1.0 / actual_zoom) * self._scan_speed_factor + naive_speed = ( + (self._x_axis.v_max - self._x_axis.v_min) * roi_scale * self._sample_rate + ) + max_speed = ( + self._x_axis.maximum_scan_speed(1.0 / actual_zoom) * self._scan_speed_factor + ) if max_speed == 0.0: # this may happen if the ROI reaches to or beyond [0,1]. In this case, the mirror has no time to accelerate # TODO: implement an auto-adjust option instead of raising an error - raise ValueError("Maximum scan speed is zero. " - "This may be because the region of interest exceeds the maximum voltage range") + raise ValueError( + "Maximum scan speed is zero. " + "This may be because the region of interest exceeds the maximum voltage range" + ) self._oversampling = int(np.ceil(unitless(naive_speed / max_speed))) oversampled_width = width * self._oversampling - v_x_even, x_launch, x_land, self._mask = self._x_axis.scan(roi_left, roi_right, oversampled_width, - self._sample_rate) + v_x_even, x_launch, x_land, self._mask = self._x_axis.scan( + roi_left, roi_right, oversampled_width, self._sample_rate + ) if self._bidirectional: - v_x_odd, _, _, _ = self._x_axis.scan(roi_right, roi_left, oversampled_width, self._sample_rate) + v_x_odd, _, _, _ = self._x_axis.scan( + roi_right, roi_left, oversampled_width, self._sample_rate + ) else: - v_xr = self._x_axis.step(x_land, x_launch, self._sample_rate) # horizontal retrace + v_xr = self._x_axis.step( + x_land, x_launch, self._sample_rate + ) # horizontal retrace v_x_even = np.concatenate((v_x_even, v_xr)) v_x_odd = v_x_even @@ -444,7 +508,7 @@ def _update(self): # For bidirectional mode, the scan pattern is padded to always have an even number of scan lines # The horizontal pattern is repeated continuously, so even during the # vertical retrace. In bidirectional scan mode, th - n_rows = self._data_shape[0] + np.ceil(len(v_yr) / len(v_x_odd)).astype('int32') + n_rows = self._data_shape[0] + np.ceil(len(v_yr) / len(v_x_odd)).astype("int32") self._n_cols = len(v_x_odd) if self._bidirectional and n_rows % 2 == 1: n_rows += 1 @@ -464,8 +528,8 @@ def _update(self): # which is essential for resonant scanning. 
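Note: the voltage ramps built in this method come from `Axis.step` and `Axis.scan`. A standalone usage sketch with made-up hardware numbers (illustrative only, assuming `Axis` is constructed from the dataclass fields defined earlier in this file):

```python
import astropy.units as u
from openwfs.devices.galvo_scanner import Axis

# illustrative numbers only; real values come from the galvo data sheet
axis = Axis(
    channel="Dev1/ao0",
    v_min=-10 * u.V,
    v_max=10 * u.V,
    scale=400 * u.um / u.V,
    maximum_acceleration=5e8 * u.V / u.s**2,
)

v = axis.step(0.0, 1.0, sample_rate=0.5 * u.MHz)  # fastest jump across the full range
print(len(v), v[0], v[-1])  # ramp starting at v_min and ending exactly at v_max
```
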
if len(v_yr) > 0: retrace = scan_pattern[0, height:, :].reshape(-1) - retrace[0:len(v_yr)] = v_yr - retrace[len(v_yr):] = v_yr[-1] + retrace[0 : len(v_yr)] = v_yr + retrace[len(v_yr) :] = v_yr[-1] self._scan_pattern = scan_pattern.reshape(2, -1) if self._test_pattern != TestPatternType.NONE: @@ -489,25 +553,39 @@ def _update(self): sample_count = self._scan_pattern.shape[1] # Configure the analog output task (two channels) - self._write_task.ao_channels.add_ao_voltage_chan(self._x_axis.channel, - min_val=self._x_axis.v_min.to_value(u.V), - max_val=self._x_axis.v_max.to_value(u.V)) - self._write_task.ao_channels.add_ao_voltage_chan(self._y_axis.channel, - min_val=self._y_axis.v_min.to_value(u.V), - max_val=self._y_axis.v_max.to_value(u.V)) - self._write_task.timing.cfg_samp_clk_timing(sample_rate, samps_per_chan=sample_count) + self._write_task.ao_channels.add_ao_voltage_chan( + self._x_axis.channel, + min_val=self._x_axis.v_min.to_value(u.V), + max_val=self._x_axis.v_max.to_value(u.V), + ) + self._write_task.ao_channels.add_ao_voltage_chan( + self._y_axis.channel, + min_val=self._y_axis.v_min.to_value(u.V), + max_val=self._y_axis.v_max.to_value(u.V), + ) + self._write_task.timing.cfg_samp_clk_timing( + sample_rate, samps_per_chan=sample_count + ) # Configure the analog input task (one channel) - self._read_task.ai_channels.add_ai_voltage_chan(self._input_channel.channel, - min_val=self._input_channel.v_min.to_value(u.V), - max_val=self._input_channel.v_max.to_value(u.V), - terminal_config=self._input_channel.terminal_configuration) - self._read_task.timing.cfg_samp_clk_timing(sample_rate, samps_per_chan=sample_count) - self._read_task.triggers.start_trigger.cfg_dig_edge_start_trig(self._write_task.triggers.start_trigger.term) + self._read_task.ai_channels.add_ai_voltage_chan( + self._input_channel.channel, + min_val=self._input_channel.v_min.to_value(u.V), + max_val=self._input_channel.v_max.to_value(u.V), + terminal_config=self._input_channel.terminal_configuration, + ) + self._read_task.timing.cfg_samp_clk_timing( + sample_rate, samps_per_chan=sample_count + ) + self._read_task.triggers.start_trigger.cfg_dig_edge_start_trig( + self._write_task.triggers.start_trigger.term + ) delay = self._delay.to_value(u.s) if delay > 0.0: self._read_task.triggers.start_trigger.delay = delay - self._read_task.triggers.start_trigger.delay_units = DigitalWidthUnits.SECONDS + self._read_task.triggers.start_trigger.delay_units = ( + DigitalWidthUnits.SECONDS + ) self._writer = AnalogMultiChannelWriter(self._write_task.out_stream) self._valid = True @@ -541,21 +619,27 @@ def _raw_to_cropped(self, raw: np.ndarray) -> np.ndarray: flips the even rows back if scanned in bidirectional mode. """ # convert data to 2-d, discard padding - cropped = raw.reshape(-1, self._n_cols)[:self._data_shape[0], self._mask] + cropped = raw.reshape(-1, self._n_cols)[: self._data_shape[0], self._mask] # down sample along fast axis if needed if self._oversampling > 1: # remove samples if not divisible by oversampling factor - cropped = cropped[:, :(cropped.shape[1] // self._oversampling) * self._oversampling] + cropped = cropped[ + :, : (cropped.shape[1] // self._oversampling) * self._oversampling + ] cropped = cropped.reshape(cropped.shape[0], -1, self._oversampling) - cropped = np.round(np.mean(cropped, 2)).astype(cropped.dtype) # todo: faster alternative? + cropped = np.round(np.mean(cropped, 2)).astype( + cropped.dtype + ) # todo: faster alternative? 
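
Note: the next step converts signed DAQ samples to `uint16` by reinterpreting the two's-complement bit pattern and adding `0x8000`. A quick standalone check of that trick (illustrative only, plain NumPy):

```python
import numpy as np

raw = np.array([-32768, -1, 0, 32767], dtype=np.int16)
shifted = raw.view("uint16") + 0x8000  # same reinterpret-and-offset used below
print(shifted)  # [    0 32767 32768 65535]
```

The addition wraps modulo 65536, so the full signed range maps monotonically onto 0 to 65535.
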
# Change the data type into uint16 if necessary if cropped.dtype == np.int16: # add 32768 to go from -32768-32767 to 0-65535 - cropped = cropped.view('uint16') + 0x8000 + cropped = cropped.view("uint16") + 0x8000 elif cropped.dtype != np.uint16: - raise ValueError(f'Only int16 and uint16 data types are supported at the moment, got type {cropped.dtype}.') + raise ValueError( + f"Only int16 and uint16 data types are supported at the moment, got type {cropped.dtype}." + ) if self._bidirectional: # note: requires the mask to be symmetrical cropped[1::2, :] = cropped[1::2, ::-1] @@ -569,31 +653,43 @@ def _fetch(self) -> np.ndarray: # noqa self._read_task.stop() self._write_task.stop() elif self._test_pattern == TestPatternType.HORIZONTAL: - raw = np.round(self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * 10000).astype('int16') + raw = np.round( + self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * 10000 + ).astype("int16") elif self._test_pattern == TestPatternType.VERTICAL: - raw = np.round(self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * 10000).astype('int16') + raw = np.round( + self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * 10000 + ).astype("int16") elif self._test_pattern == TestPatternType.IMAGE: if self._test_image is None: - raise ValueError('No test image was provided for the image simulation.') + raise ValueError("No test image was provided for the image simulation.") # todo: cache the test image row = np.floor( - self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * (self._test_image.shape[0] - 1)).astype( - 'int32') + self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) + * (self._test_image.shape[0] - 1) + ).astype("int32") column = np.floor( - self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * (self._test_image.shape[1] - 1)).astype( - 'int32') + self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) + * (self._test_image.shape[1] - 1) + ).astype("int32") raw = self._test_image[row, column] else: - raise ValueError(f"Invalid simulation option {self._test_pattern}. " - "Should be 'horizontal', 'vertical', 'image', or 'None'") + raise ValueError( + f"Invalid simulation option {self._test_pattern}. " + "Should be 'horizontal', 'vertical', 'image', or 'None'" + ) # Preprocess raw data if a preprocess function is set if self._preprocessor is None: preprocessed_raw = raw elif callable(self._preprocessor): - preprocessed_raw = self._preprocessor(data=raw, sample_rate=self._sample_rate) + preprocessed_raw = self._preprocessor( + data=raw, sample_rate=self._sample_rate + ) else: - raise TypeError(f"Invalid type for {self._preprocessor}. Should be callable or None.") + raise TypeError( + f"Invalid type for {self._preprocessor}. Should be callable or None." + ) return self._raw_to_cropped(preprocessed_raw) def close(self): @@ -622,7 +718,9 @@ def preprocessor(self): @preprocessor.setter def preprocessor(self, value: Optional[callable]): if not callable(value) and value is not None: - raise TypeError(f"Invalid type for {self._preprocessor}. Should be callable or None.") + raise TypeError( + f"Invalid type for {self._preprocessor}. Should be callable or None." 
+ ) self._preprocessor = value @property @@ -631,8 +729,10 @@ def pixel_size(self) -> Quantity: # TODO: make extent a read-only attribute of Axis extent_y = (self._y_axis.v_max - self._y_axis.v_min) * self._y_axis.scale extent_x = (self._x_axis.v_max - self._x_axis.v_min) * self._x_axis.scale - return (Quantity(extent_y, extent_x) / ( - self._reference_zoom * self._zoom * self._resolution)).to(u.um) + return ( + Quantity(extent_y, extent_x) + / (self._reference_zoom * self._zoom * self._resolution) + ).to(u.um) @property def duration(self) -> Quantity[u.ms]: @@ -832,7 +932,7 @@ def binning(self) -> int: @binning.setter def binning(self, value: int): if value < 1: - raise ValueError('Binning value should be a positive integer') + raise ValueError("Binning value should be a positive integer") self._scale_roi(self._binning / int(value)) self._binning = int(value) diff --git a/openwfs/devices/nidaq_gain.py b/openwfs/devices/nidaq_gain.py index e2971e3..b3c6dea 100644 --- a/openwfs/devices/nidaq_gain.py +++ b/openwfs/devices/nidaq_gain.py @@ -55,7 +55,9 @@ def check_overload(self): def on_reset(self, value): if value: with ni.Task() as task: - task.do_channels.add_do_chan(self.port_do, line_grouping=LineGrouping.CHAN_FOR_ALL_LINES) + task.do_channels.add_do_chan( + self.port_do, line_grouping=LineGrouping.CHAN_FOR_ALL_LINES + ) task.write([True]) time.sleep(1) task.write([False]) @@ -84,4 +86,3 @@ def gain(self, value: Quantity[u.V]): channel.ao_min = 0 channel.ao_max = 0.9 write_task.write(self._gain.to_value(u.V)) - diff --git a/openwfs/devices/slm/context.py b/openwfs/devices/slm/context.py index b36764b..0e33cf7 100644 --- a/openwfs/devices/slm/context.py +++ b/openwfs/devices/slm/context.py @@ -3,7 +3,7 @@ import glfw -SLM = 'slm.SLM' +SLM = "slm.SLM" class Context: @@ -15,6 +15,7 @@ class Context: one thread can use OpenGL at the same time. This class holds a weak ref to the SLM object, so that the SLM object can be garbage collected. """ + _lock = threading.RLock() def __init__(self, slm): diff --git a/openwfs/devices/slm/geometry.py b/openwfs/devices/slm/geometry.py index 0506cbe..b6f73a1 100644 --- a/openwfs/devices/slm/geometry.py +++ b/openwfs/devices/slm/geometry.py @@ -110,15 +110,26 @@ def rectangle(extent: ExtentType, center: CoordinateType = (0, 0)) -> Geometry: right = center[1] + 0.5 * extent[1] bottom = center[0] + 0.5 * extent[0] - vertices = np.array(([left, top, 0.0, 0.0], [right, top, 1.0, 0.0], - [left, bottom, 0.0, 1.0], [right, bottom, 1.0, 1.0]), dtype=np.float32) + vertices = np.array( + ( + [left, top, 0.0, 0.0], + [right, top, 1.0, 0.0], + [left, bottom, 0.0, 1.0], + [right, bottom, 1.0, 1.0], + ), + dtype=np.float32, + ) indices = Geometry.compute_indices_for_grid((1, 1)) return Geometry(vertices, indices) -def circular(radii: Sequence[float], segments_per_ring: Sequence[int], edge_count: int = 256, - center: CoordinateType = (0, 0)) -> Geometry: +def circular( + radii: Sequence[float], + segments_per_ring: Sequence[int], + edge_count: int = 256, + center: CoordinateType = (0, 0), +) -> Geometry: """Creates a circular geometry with the specified extent. This geometry maps a texture to a disk or a ring. @@ -153,7 +164,8 @@ def circular(radii: Sequence[float], segments_per_ring: Sequence[int], edge_coun if len(segments_per_ring) != ring_count: raise ValueError( "The length of `radii` and `segments_per_ring` should both equal the number of rings (counting " - "the inner disk as the first ring).") + "the inner disk as the first ring)." 
+ ) # construct coordinates of points on a circle of radius 1.0 # the start and end point coincide @@ -172,15 +184,19 @@ def circular(radii: Sequence[float], segments_per_ring: Sequence[int], edge_coun segments_inside = 0 total_segments = np.sum(segments_per_ring) for r in range(ring_count): - x_outside = x * radii[r + 1] # coordinates of the vertices at the outside of the ring + x_outside = ( + x * radii[r + 1] + ) # coordinates of the vertices at the outside of the ring y_outside = y * radii[r + 1] segments = segments_inside + segments_per_ring[r] vertices[r, 0, :, 0] = x_inside + center[1] vertices[r, 0, :, 1] = y_inside + center[0] vertices[r, 1, :, 0] = x_outside + center[1] vertices[r, 1, :, 1] = y_outside + center[0] - vertices[r, :, :, 2] = np.linspace(segments_inside, segments, edge_count + 1).reshape( - (1, -1)) / total_segments # tx + vertices[r, :, :, 2] = ( + np.linspace(segments_inside, segments, edge_count + 1).reshape((1, -1)) + / total_segments + ) # tx x_inside = x_outside y_inside = y_outside segments_inside = segments @@ -190,6 +206,9 @@ def circular(radii: Sequence[float], segments_per_ring: Sequence[int], edge_coun # construct indices for a single ring, and repeat for all rings with the appropriate offset indices = Geometry.compute_indices_for_grid((1, edge_count)).reshape((1, -1)) - indices = indices + np.arange(ring_count).reshape((-1, 1)) * vertices.shape[1] * vertices.shape[2] + indices = ( + indices + + np.arange(ring_count).reshape((-1, 1)) * vertices.shape[1] * vertices.shape[2] + ) indices[:, -1] = 0xFFFF return Geometry(vertices.reshape((-1, 4)), indices.reshape(-1)) diff --git a/openwfs/devices/slm/patch.py b/openwfs/devices/slm/patch.py index a32c14c..7c55840 100644 --- a/openwfs/devices/slm/patch.py +++ b/openwfs/devices/slm/patch.py @@ -8,18 +8,44 @@ try: import OpenGL.GL as GL - from OpenGL.GL import glGenBuffers, glBindBuffer, glBufferData, glDeleteBuffers, glEnable, glBlendFunc, \ - glBlendEquation, glDisable, glUseProgram, glBindVertexBuffer, glDrawElements, glGenFramebuffers, \ - glBindFramebuffer, glFramebufferTexture2D, glCheckFramebufferStatus, glDeleteFramebuffers, \ - glEnableVertexAttribArray, glVertexAttribFormat, glVertexAttribBinding, glEnableVertexAttribArray, \ - glPrimitiveRestartIndex, glActiveTexture, glBindTexture, glGenVertexArrays, glBindVertexArray + from OpenGL.GL import ( + glGenBuffers, + glBindBuffer, + glBufferData, + glDeleteBuffers, + glEnable, + glBlendFunc, + glBlendEquation, + glDisable, + glUseProgram, + glBindVertexBuffer, + glDrawElements, + glGenFramebuffers, + glBindFramebuffer, + glFramebufferTexture2D, + glCheckFramebufferStatus, + glDeleteFramebuffers, + glEnableVertexAttribArray, + glVertexAttribFormat, + glVertexAttribBinding, + glEnableVertexAttribArray, + glPrimitiveRestartIndex, + glActiveTexture, + glBindTexture, + glGenVertexArrays, + glBindVertexArray, + ) from OpenGL.GL import shaders except AttributeError: warnings.warn("OpenGL not found, SLM will not work") from .geometry import rectangle, Geometry -from .shaders import default_vertex_shader, default_fragment_shader, \ - post_process_fragment_shader, post_process_vertex_shader +from .shaders import ( + default_vertex_shader, + default_fragment_shader, + post_process_fragment_shader, + post_process_vertex_shader, +) from .texture import Texture from ...core import PhaseSLM @@ -27,8 +53,13 @@ class Patch(PhaseSLM): _PHASES_TEXTURE = 0 # indices of the phases texture in the _texture array - def __init__(self, slm, geometry=None, 
vertex_shader=default_vertex_shader, - fragment_shader=default_fragment_shader): + def __init__( + self, + slm, + geometry=None, + vertex_shader=default_vertex_shader, + fragment_shader=default_fragment_shader, + ): """ Constructs a new patch (a shape) that can be drawn on the screen. By default, the patch is a square with 'radius' 1.0 (width and height 2.0) centered at 0.0, 0.0 @@ -81,7 +112,9 @@ def _draw(self): # perform the actual drawing glBindBuffer(GL.GL_ELEMENT_ARRAY_BUFFER, self._indices) glBindVertexBuffer(0, self._vertices, 0, 16) - glDrawElements(GL.GL_TRIANGLE_STRIP, self._index_count, GL.GL_UNSIGNED_SHORT, None) + glDrawElements( + GL.GL_TRIANGLE_STRIP, self._index_count, GL.GL_UNSIGNED_SHORT, None + ) def set_phases(self, values: ArrayLike, update=True): """ @@ -123,9 +156,19 @@ def geometry(self, value: Geometry): (self._vertices, self._indices) = glGenBuffers(2) self._index_count = value.indices.size glBindBuffer(GL.GL_ARRAY_BUFFER, self._vertices) - glBufferData(GL.GL_ARRAY_BUFFER, value.vertices.size * 4, value.vertices, GL.GL_DYNAMIC_DRAW) + glBufferData( + GL.GL_ARRAY_BUFFER, + value.vertices.size * 4, + value.vertices, + GL.GL_DYNAMIC_DRAW, + ) glBindBuffer(GL.GL_ELEMENT_ARRAY_BUFFER, self._indices) - glBufferData(GL.GL_ELEMENT_ARRAY_BUFFER, value.indices.size * 2, value.indices, GL.GL_DYNAMIC_DRAW) + glBufferData( + GL.GL_ELEMENT_ARRAY_BUFFER, + value.indices.size * 2, + value.indices, + GL.GL_DYNAMIC_DRAW, + ) class FrameBufferPatch(Patch): @@ -137,22 +180,34 @@ class FrameBufferPatch(Patch): _textures: list[Texture] def __init__(self, slm, lookup_table: Sequence[int]): - super().__init__(slm, fragment_shader=post_process_fragment_shader, - vertex_shader=post_process_vertex_shader) + super().__init__( + slm, + fragment_shader=post_process_fragment_shader, + vertex_shader=post_process_vertex_shader, + ) # Create a frame buffer object to render to. The frame buffer holds a texture that is the same size as the # window. All patches are first rendered to this texture. The texture # is then processed as a whole (applying the software lookup table) and displayed on the screen. 
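
Note: the lookup table applied in this post-processing pass maps a wrapped phase to a gray level. A rough NumPy sketch of the idea (illustrative only; the real work happens in the GLSL shader and may round differently):

```python
import numpy as np

lookup_table = np.arange(256, dtype=np.uint8)  # identity table, like the default range(256)
phase = np.array([0.0, 0.5 * np.pi, np.pi, 1.5 * np.pi])  # wrapped phase in radians
index = np.round(phase / (2 * np.pi) * len(lookup_table)).astype(int) % len(lookup_table)
print(lookup_table[index])  # [  0  64 128 192]
```
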
self._frame_buffer = glGenFramebuffers(1) - self.set_phases(np.zeros(self.context.slm.shape, dtype=np.float32), update=False) + self.set_phases( + np.zeros(self.context.slm.shape, dtype=np.float32), update=False + ) glBindFramebuffer(GL.GL_FRAMEBUFFER, self._frame_buffer) - glFramebufferTexture2D(GL.GL_FRAMEBUFFER, GL.GL_COLOR_ATTACHMENT0, GL.GL_TEXTURE_2D, - self._textures[Patch._PHASES_TEXTURE].handle, 0) + glFramebufferTexture2D( + GL.GL_FRAMEBUFFER, + GL.GL_COLOR_ATTACHMENT0, + GL.GL_TEXTURE_2D, + self._textures[Patch._PHASES_TEXTURE].handle, + 0, + ) if glCheckFramebufferStatus(GL.GL_FRAMEBUFFER) != GL.GL_FRAMEBUFFER_COMPLETE: raise Exception("Could not construct frame buffer") glBindFramebuffer(GL.GL_FRAMEBUFFER, 0) - self._textures.append(Texture(self.context, GL.GL_TEXTURE_1D)) # create texture for lookup table + self._textures.append( + Texture(self.context, GL.GL_TEXTURE_1D) + ) # create texture for lookup table self._lookup_table = None self.lookup_table = lookup_table self.additive_blend = False @@ -164,7 +219,7 @@ def __del__(self): @property def lookup_table(self): - """1-D array """ + """1-D array""" return self._lookup_table @lookup_table.setter @@ -195,15 +250,25 @@ class VertexArray: # Since we have a fixed vertex format, we only need to bind the VertexArray once, and not bother with # updating, binding, or even deleting it def __init__(self): - self._vertex_array = glGenVertexArrays(1) # no need to destroy explicitly, destroyed when window is destroyed + self._vertex_array = glGenVertexArrays( + 1 + ) # no need to destroy explicitly, destroyed when window is destroyed glBindVertexArray(self._vertex_array) glEnableVertexAttribArray(0) glEnableVertexAttribArray(1) - glVertexAttribFormat(0, 2, GL.GL_FLOAT, GL.GL_FALSE, 0) # first two float32 are screen coordinates - glVertexAttribFormat(1, 2, GL.GL_FLOAT, GL.GL_FALSE, 8) # second two are texture coordinates + glVertexAttribFormat( + 0, 2, GL.GL_FLOAT, GL.GL_FALSE, 0 + ) # first two float32 are screen coordinates + glVertexAttribFormat( + 1, 2, GL.GL_FLOAT, GL.GL_FALSE, 8 + ) # second two are texture coordinates glVertexAttribBinding(0, 0) # use binding index 0 for both attributes - glVertexAttribBinding(1, 0) # the attribute format can now be used with glBindVertexBuffer + glVertexAttribBinding( + 1, 0 + ) # the attribute format can now be used with glBindVertexBuffer # enable primitive restart, so that we can draw multiple triangle strips with a single draw call glEnable(GL.GL_PRIMITIVE_RESTART) - glPrimitiveRestartIndex(0xFFFF) # this is the index we use to separate individual triangle strips + glPrimitiveRestartIndex( + 0xFFFF + ) # this is the index we use to separate individual triangle strips diff --git a/openwfs/devices/slm/slm.py b/openwfs/devices/slm/slm.py index e177069..d1a4750 100644 --- a/openwfs/devices/slm/slm.py +++ b/openwfs/devices/slm/slm.py @@ -13,8 +13,19 @@ try: import OpenGL.GL as GL - from OpenGL.GL import glViewport, glClearColor, glClear, glGenBuffers, glReadBuffer, glReadPixels, glFinish, \ - glBindBuffer, glBufferData, glBindBufferBase, glBindFramebuffer + from OpenGL.GL import ( + glViewport, + glClearColor, + glClear, + glGenBuffers, + glReadBuffer, + glReadPixels, + glFinish, + glBindBuffer, + glBufferData, + glBindBufferBase, + glBindFramebuffer, + ) except AttributeError: warnings.warn("OpenGL not found, SLM will not work") from .patch import FrameBufferPatch, Patch, VertexArray @@ -31,9 +42,27 @@ class SLM(Actuator, PhaseSLM): See :numref:`section-slms` for more information. 
""" - __slots__ = ['_vertex_array', '_frame_buffer', '_monitor_id', '_position', '_refresh_rate', - '_transform', '_shape', '_window', '_globals', '_frame_buffer', 'patches', 'primary_patch', - '_coordinate_system', '_pixel_reader', '_phase_reader', '_field_reader', '_context', '_clones'] + + __slots__ = [ + "_vertex_array", + "_frame_buffer", + "_monitor_id", + "_position", + "_refresh_rate", + "_transform", + "_shape", + "_window", + "_globals", + "_frame_buffer", + "patches", + "primary_patch", + "_coordinate_system", + "_pixel_reader", + "_phase_reader", + "_field_reader", + "_context", + "_clones", + ] _active_slms = WeakSet() """Keep track of all active SLMs. This is done for two reasons. First, to check if we are not putting two @@ -43,10 +72,17 @@ class SLM(Actuator, PhaseSLM): WINDOWED = 0 patches: list[Patch] - def __init__(self, monitor_id: int = WINDOWED, shape: Optional[tuple[int, int]] = None, - pos: tuple[int, int] = (0, 0), refresh_rate: Optional[Quantity[u.Hz]] = None, - latency: TimeType = 2, duration: TimeType = 1, coordinate_system: str = 'short', - transform: Optional[Transform] = None): + def __init__( + self, + monitor_id: int = WINDOWED, + shape: Optional[tuple[int, int]] = None, + pos: tuple[int, int] = (0, 0), + refresh_rate: Optional[Quantity[u.Hz]] = None, + latency: TimeType = 2, + duration: TimeType = 1, + coordinate_system: str = "short", + transform: Optional[Transform] = None, + ): """ Constructs a new SLM window. @@ -86,7 +122,9 @@ def __init__(self, monitor_id: int = WINDOWED, shape: Optional[tuple[int, int]] self._position = pos (default_shape, default_rate, _) = SLM._current_mode(monitor_id) self._shape = default_shape if shape is None else shape - self._refresh_rate = default_rate if refresh_rate is None else refresh_rate.to_value(u.Hz) + self._refresh_rate = ( + default_rate if refresh_rate is None else refresh_rate.to_value(u.Hz) + ) self._frame_buffer = None self._window = None self._globals = -1 @@ -121,20 +159,32 @@ def _assert_window_available(self, monitor_id) -> None: Exception: If a full screen SLM is already present on the target monitor. """ if monitor_id == SLM.WINDOWED: - if any([slm.monitor_id == 1 for slm in SLM._active_slms if slm is not self]): + if any( + [slm.monitor_id == 1 for slm in SLM._active_slms if slm is not self] + ): raise RuntimeError( - f"Cannot create an SLM window because a full-screen SLM is already active on monitor 1") + f"Cannot create an SLM window because a full-screen SLM is already active on monitor 1" + ) else: # we cannot have multiple full screen windows on the same monitor. Also, we cannot have # a full screen window on monitor 1 if there are already windowed SLMs. 
- if any([slm.monitor_id == monitor_id or - (monitor_id == 1 and slm.monitor_id == SLM.WINDOWED) - for slm in SLM._active_slms if slm is not self]): - raise RuntimeError(f"Cannot create a full-screen SLM window on monitor {monitor_id} because a " - f"window is already displayed on that monitor") + if any( + [ + slm.monitor_id == monitor_id + or (monitor_id == 1 and slm.monitor_id == SLM.WINDOWED) + for slm in SLM._active_slms + if slm is not self + ] + ): + raise RuntimeError( + f"Cannot create a full-screen SLM window on monitor {monitor_id} because a " + f"window is already displayed on that monitor" + ) if monitor_id > len(glfw.get_monitors()): - raise IndexError(f"Monitor {monitor_id} not found, only {len(glfw.get_monitors())} monitor(s) " - f"are connected.") + raise IndexError( + f"Monitor {monitor_id} not found, only {len(glfw.get_monitors())} monitor(s) " + f"are connected." + ) @staticmethod def _current_mode(monitor_id: int): @@ -152,7 +202,11 @@ def _current_mode(monitor_id: int): mode = glfw.get_video_mode(monitor) shape = (mode.size[1], mode.size[0]) - return shape, mode.refresh_rate, min([mode.bits.red, mode.bits.green, mode.bits.blue]) + return ( + shape, + mode.refresh_rate, + min([mode.bits.red, mode.bits.green, mode.bits.blue]), + ) def _on_resize(self): """Updates shape and refresh rate to the actual values of the window. @@ -169,7 +223,11 @@ def _on_resize(self): """ # create a new frame buffer, re-use the old one if one was present, otherwise use a default of range(256) # re-use the lookup table if possible, otherwise create a default one ranging from 0 to 255. - old_lut = self._frame_buffer.lookup_table if self._frame_buffer is not None else range(256) + old_lut = ( + self._frame_buffer.lookup_table + if self._frame_buffer is not None + else range(256) + ) self._frame_buffer = FrameBufferPatch(self, old_lut) glViewport(0, 0, self._shape[1], self._shape[0]) # tell openGL to wait for the vertical retrace when swapping buffers (it appears need to do this @@ -180,22 +238,29 @@ def _on_resize(self): (fb_width, fb_height) = glfw.get_framebuffer_size(self._window) fb_shape = (fb_height, fb_width) if self._shape != fb_shape: - warnings.warn(f"Actual resolution {fb_shape} does not match requested resolution {self._shape}.") + warnings.warn( + f"Actual resolution {fb_shape} does not match requested resolution {self._shape}." + ) self._shape = fb_shape - (current_size, current_rate, current_bit_depth) = SLM._current_mode(self._monitor_id) + (current_size, current_rate, current_bit_depth) = SLM._current_mode( + self._monitor_id + ) # verify that the bit depth is at least 8 bit if current_bit_depth < 8: warnings.warn( f"Bit depth is less than 8 bits " - f"You may not be able to use the full phase resolution of your SLM.") + f"You may not be able to use the full phase resolution of your SLM." + ) # verify the refresh rate is correct # Then update the refresh rate to the actual value if int(self._refresh_rate) != current_rate: - warnings.warn(f"Actual refresh rate of {current_rate} Hz does not match set rate " - f"of {self._refresh_rate} Hz") + warnings.warn( + f"Actual refresh rate of {current_rate} Hz does not match set rate " + f"of {self._refresh_rate} Hz" + ) self._refresh_rate = current_rate @staticmethod @@ -208,19 +273,29 @@ def _init_glfw(): trouble if the user of our library also uses glfw for something else. """ glfw.init() - glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) # Required on Mac. 
Doesn't hurt on Windows - glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, glfw.TRUE) # Required on Mac. Useless on Windows + glfw.window_hint( + glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE + ) # Required on Mac. Doesn't hurt on Windows + glfw.window_hint( + glfw.OPENGL_FORWARD_COMPAT, glfw.TRUE + ) # Required on Mac. Useless on Windows glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 4) # request at least opengl 4.2 glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 2) glfw.window_hint(glfw.FLOATING, glfw.TRUE) # Keep window on top glfw.window_hint(glfw.DECORATED, glfw.FALSE) # Disable window border - glfw.window_hint(glfw.AUTO_ICONIFY, glfw.FALSE) # Prevent window minimization during task switch + glfw.window_hint( + glfw.AUTO_ICONIFY, glfw.FALSE + ) # Prevent window minimization during task switch glfw.window_hint(glfw.FOCUSED, glfw.FALSE) glfw.window_hint(glfw.DOUBLEBUFFER, glfw.TRUE) - glfw.window_hint(glfw.RED_BITS, 8) # require at least 8 bits per color channel (256 gray values) + glfw.window_hint( + glfw.RED_BITS, 8 + ) # require at least 8 bits per color channel (256 gray values) glfw.window_hint(glfw.GREEN_BITS, 8) glfw.window_hint(glfw.BLUE_BITS, 8) - glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, glfw.FALSE) # disable retina multisampling on Mac (untested) + glfw.window_hint( + glfw.COCOA_RETINA_FRAMEBUFFER, glfw.FALSE + ) # disable retina multisampling on Mac (untested) glfw.window_hint(glfw.SAMPLES, 0) # disable multisampling def _create_window(self): @@ -232,10 +307,18 @@ def _create_window(self): shared = other._window if other is not None else None # noqa: ok to use _window SLM._active_slms.add(self) - monitor = glfw.get_monitors()[self._monitor_id - 1] if self._monitor_id != SLM.WINDOWED else None + monitor = ( + glfw.get_monitors()[self._monitor_id - 1] + if self._monitor_id != SLM.WINDOWED + else None + ) glfw.window_hint(glfw.REFRESH_RATE, int(self._refresh_rate)) - self._window = glfw.create_window(self._shape[1], self._shape[0], "OpenWFS SLM", monitor, shared) - glfw.set_input_mode(self._window, glfw.CURSOR, glfw.CURSOR_HIDDEN) # disable cursor + self._window = glfw.create_window( + self._shape[1], self._shape[0], "OpenWFS SLM", monitor, shared + ) + glfw.set_input_mode( + self._window, glfw.CURSOR, glfw.CURSOR_HIDDEN + ) # disable cursor if monitor: # full screen mode glfw.set_gamma(monitor, 1.0) else: # windowed mode @@ -300,8 +383,7 @@ def refresh_rate(self) -> Quantity[u.Hz]: @property def period(self) -> Quantity[u.ms]: - """The period of the refresh rate in milliseconds (read only). 
- """ + """The period of the refresh rate in milliseconds (read only).""" return (1000 / self._refresh_rate) * u.ms @property @@ -331,9 +413,15 @@ def monitor_id(self, value): monitor = glfw.get_monitors()[value - 1] if value != SLM.WINDOWED else None # move window to new monitor - glfw.set_window_monitor(self._window, monitor, self._position[1], self._position[0], self._shape[1], - self._shape[0], - int(self._refresh_rate)) + glfw.set_window_monitor( + self._window, + monitor, + self._position[1], + self._position[0], + self._shape[1], + self._shape[0], + int(self._refresh_rate), + ) self._on_resize() def __del__(self): @@ -361,7 +449,9 @@ def update(self): """ with self._context: # first draw all patches into the frame buffer - glBindFramebuffer(GL.GL_FRAMEBUFFER, self._frame_buffer._frame_buffer) # noqa - ok to access 'friend class' + glBindFramebuffer( + GL.GL_FRAMEBUFFER, self._frame_buffer._frame_buffer + ) # noqa - ok to access 'friend class' glClear(GL.GL_COLOR_BUFFER_BIT) for patch in self.patches: patch._draw() # noqa - ok to access 'friend class' @@ -373,7 +463,9 @@ def update(self): glfw.poll_events() # process window messages if len(self._clones) > 0: - self._context.__exit__(None, None, None) # release context before updating clones + self._context.__exit__( + None, None, None + ) # release context before updating clones for clone in self._clones: with clone.slm._context: # noqa self._frame_buffer._draw() # noqa - ok to access 'friend class' @@ -413,31 +505,31 @@ def duration(self, value: Quantity[u.ms]): def coordinate_system(self) -> str: """Specifies the base coordinate system that is used to map vertex coordinates to the SLM window. - Possible values are 'full', 'short' and 'long'. + Possible values are 'full', 'short' and 'long'. - 'full' means that the coordinate range (-1,-1) to (1,1) is mapped to the entire SLM window. - If the window is not square, this means that the coordinates are anisotropic. + 'full' means that the coordinate range (-1,-1) to (1,1) is mapped to the entire SLM window. + If the window is not square, this means that the coordinates are anisotropic. - 'short' and 'long' map the coordinate range (-1,-1) to (1,1) to a square. - 'short' means that the square is scaled to fill the short side of the SLM (introducing zero-padding at the - edges). + 'short' and 'long' map the coordinate range (-1,-1) to (1,1) to a square. + 'short' means that the square is scaled to fill the short side of the SLM (introducing zero-padding at the + edges). - 'long' means that the square is scaled to fill the long side of the SLM - (causing part of the coordinate range to be cropped because these coordinates correspond to points outside - the SLM window). + 'long' means that the square is scaled to fill the long side of the SLM + (causing part of the coordinate range to be cropped because these coordinates correspond to points outside + the SLM window). - For a square SLM, 'full', 'short' and 'long' are all equivalent. + For a square SLM, 'full', 'short' and 'long' are all equivalent. - In all three cases, (-1,-1) corresponds to the top-left corner of the screen, and (1,-1) to the - bottom-left corner. This convention is consistent with that used in numpy/matplotlib + In all three cases, (-1,-1) corresponds to the top-left corner of the screen, and (1,-1) to the + bottom-left corner. This convention is consistent with that used in numpy/matplotlib - To further modify the mapping system, use the `transform` property. 
+ To further modify the mapping system, use the `transform` property. """ return self._coordinate_system @coordinate_system.setter def coordinate_system(self, value: str): - if value not in ['full', 'short', 'long']: + if value not in ["full", "short", "long"]: raise ValueError(f"Unsupported coordinate system {value}") self._coordinate_system = value self.transform = self._transform # trigger update of transform matrix on gpu @@ -466,21 +558,29 @@ def transform(self, value: Transform): self._field_reader = None # update matrix stored on the gpu - if self._coordinate_system == 'full': + if self._coordinate_system == "full": transform = self._transform else: - scale_width = (width > height) == (self._coordinate_system == 'short') + scale_width = (width > height) == (self._coordinate_system == "short") if scale_width: - root_transform = Transform(np.array(((1.0, 0.0), (0.0, height / width)))) + root_transform = Transform( + np.array(((1.0, 0.0), (0.0, height / width))) + ) else: - root_transform = Transform(np.array(((width / height, 0.0), (0.0, 1.0)))) + root_transform = Transform( + np.array(((width / height, 0.0), (0.0, 1.0))) + ) transform = self._transform @ root_transform padded = transform.opencl_matrix() with self._context: glBindBuffer(GL.GL_UNIFORM_BUFFER, self._globals) - glBufferData(GL.GL_UNIFORM_BUFFER, padded.size * 4, padded, GL.GL_STATIC_DRAW) - glBindBufferBase(GL.GL_UNIFORM_BUFFER, 1, self._globals) # connect buffer to binding point 1 + glBufferData( + GL.GL_UNIFORM_BUFFER, padded.size * 4, padded, GL.GL_STATIC_DRAW + ) + glBindBufferBase( + GL.GL_UNIFORM_BUFFER, 1, self._globals + ) # connect buffer to binding point 1 @property def lookup_table(self) -> Sequence[int]: @@ -527,8 +627,12 @@ def phases(self) -> Detector: self._phase_reader = FrameBufferReader(self) return self._phase_reader - def clone(self, monitor_id: int = WINDOWED, shape: Optional[tuple[int, int]] = None, - pos: tuple[int, int] = (0, 0)): + def clone( + self, + monitor_id: int = WINDOWED, + shape: Optional[tuple[int, int]] = None, + pos: tuple[int, int] = (0, 0), + ): """Creates a new SLM window that mirrors the content of this SLM window. This is useful for demonstration and debugging purposes. 
@@ -560,8 +664,13 @@ def __init__(self, slm: SLM):
 class FrontBufferReader(Detector):
     def __init__(self, slm):
         self._context = Context(slm)
-        super().__init__(data_shape=None, pixel_size=None, duration=0.0 * u.ms, latency=0.0 * u.ms,
-                         multi_threaded=False)
+        super().__init__(
+            data_shape=None,
+            pixel_size=None,
+            duration=0.0 * u.ms,
+            latency=0.0 * u.ms,
+            multi_threaded=False,
+        )
 
     @property
     def data_shape(self):
@@ -571,7 +680,7 @@ def _fetch(self, *args, **kwargs) -> np.ndarray:
         with self._context:
             glReadBuffer(GL.GL_FRONT)
             shape = self.data_shape
-            data = np.empty(shape, dtype='uint8')
+            data = np.empty(shape, dtype="uint8")
             glReadPixels(0, 0, shape[1], shape[0], GL.GL_RED, GL.GL_UNSIGNED_BYTE, data)
             # flip data upside down, because the OpenGL convention is to have the origin at the bottom left,
             # but we want it at the top left (like in numpy)
@@ -581,8 +690,13 @@ class FrameBufferReader(Detector):
     def __init__(self, slm):
         self._context = Context(slm)
-        super().__init__(data_shape=None, pixel_size=None, duration=0.0 * u.ms, latency=0.0 * u.ms,
-                         multi_threaded=False)
+        super().__init__(
+            data_shape=None,
+            pixel_size=None,
+            duration=0.0 * u.ms,
+            latency=0.0 * u.ms,
+            multi_threaded=False,
+        )
 
     @property
     def data_shape(self):
diff --git a/openwfs/devices/slm/texture.py b/openwfs/devices/slm/texture.py
index 6909abe..97c83f1 100644
--- a/openwfs/devices/slm/texture.py
+++ b/openwfs/devices/slm/texture.py
@@ -6,8 +6,19 @@
 try:
     import OpenGL.GL as GL
-    from OpenGL.GL import glGenTextures, glBindTexture, glTexImage2D, glTexSubImage2D, glTexImage1D, glTexSubImage1D, \
-        glTexParameteri, glActiveTexture, glDeleteTextures, glGetTextureImage, glPixelStorei
+    from OpenGL.GL import (
+        glGenTextures,
+        glBindTexture,
+        glTexImage2D,
+        glTexSubImage2D,
+        glTexImage1D,
+        glTexSubImage1D,
+        glTexParameteri,
+        glActiveTexture,
+        glDeleteTextures,
+        glGetTextureImage,
+        glPixelStorei,
+    )
 except AttributeError:
     warnings.warn("OpenGL not found, SLM will not work"),
 
 
@@ -17,7 +28,9 @@
     def __init__(self, slm, texture_type=GL.GL_TEXTURE_2D):
         self.context = Context(slm)
         self.handle = glGenTextures(1)
         self.type = texture_type
-        self.synchronized = False  # self.data is not yet synchronized with texture in GPU memory
+        self.synchronized = (
+            False  # self.data is not yet synchronized with texture in GPU memory
+        )
         self._data_shape = None  # current size of the texture, to see if we need to make a new texture or
         # overwrite the existing one
@@ -37,22 +50,28 @@ def __del__(self):
             glDeleteTextures(1, [self.handle])
 
     def _bind(self, idx):
-        """ Bind texture to texture unit idx. Assumes that the OpenGL context is already active."""
+        """Bind texture to texture unit idx. Assumes that the OpenGL context is already active."""
        glActiveTexture(GL.GL_TEXTURE0 + idx)
         glBindTexture(self.type, self.handle)
 
     def set_data(self, value):
-        """ Set texture data.
+        """Set texture data.
 
        The texture data is directly copied to the GPU memory, so the original data array can be modified or deleted.
""" - value = np.array(value, dtype=np.float32, order='C', copy=False) + value = np.array(value, dtype=np.float32, order="C", copy=False) with self.context: glBindTexture(self.type, self.handle) - glPixelStorei(GL.GL_UNPACK_ALIGNMENT, 4) # alignment is at least four bytes since we use float32 - (internal_format, data_format, data_type) = (GL.GL_R32F, GL.GL_RED, GL.GL_FLOAT) + glPixelStorei( + GL.GL_UNPACK_ALIGNMENT, 4 + ) # alignment is at least four bytes since we use float32 + (internal_format, data_format, data_type) = ( + GL.GL_R32F, + GL.GL_RED, + GL.GL_FLOAT, + ) if self.type == GL.GL_TEXTURE_1D: # check if data has the correct dimension, convert scalars to arrays of correct dimension @@ -62,11 +81,28 @@ def set_data(self, value): raise ValueError("Data should be a 1-d array or a scalar") if value.shape != self._data_shape: # create a new texture - glTexImage1D(GL.GL_TEXTURE_1D, 0, internal_format, value.shape[0], 0, data_format, data_type, value) + glTexImage1D( + GL.GL_TEXTURE_1D, + 0, + internal_format, + value.shape[0], + 0, + data_format, + data_type, + value, + ) self._data_shape = value.shape else: # overwrite existing texture - glTexSubImage1D(GL.GL_TEXTURE_1D, 0, 0, value.shape[0], data_format, data_type, value) + glTexSubImage1D( + GL.GL_TEXTURE_1D, + 0, + 0, + value.shape[0], + data_format, + data_type, + value, + ) elif self.type == GL.GL_TEXTURE_2D: if value.ndim == 0: @@ -74,17 +110,37 @@ def set_data(self, value): elif value.ndim != 2: raise ValueError("Data should be a 2-D array or a scalar") if value.shape != self._data_shape: - glTexImage2D(GL.GL_TEXTURE_2D, 0, internal_format, value.shape[1], value.shape[0], 0, - data_format, data_type, value) + glTexImage2D( + GL.GL_TEXTURE_2D, + 0, + internal_format, + value.shape[1], + value.shape[0], + 0, + data_format, + data_type, + value, + ) self._data_shape = value.shape else: - glTexSubImage2D(GL.GL_TEXTURE_2D, 0, 0, 0, value.shape[1], value.shape[0], data_format, - data_type, value) + glTexSubImage2D( + GL.GL_TEXTURE_2D, + 0, + 0, + 0, + value.shape[1], + value.shape[0], + data_format, + data_type, + value, + ) else: raise ValueError("Texture type not supported") def get_data(self): with self.context: - data = np.empty(self._data_shape, dtype='float32') - glGetTextureImage(self.handle, 0, GL.GL_RED, GL.GL_FLOAT, data.size * 4, data) + data = np.empty(self._data_shape, dtype="float32") + glGetTextureImage( + self.handle, 0, GL.GL_RED, GL.GL_FLOAT, data.size * 4, data + ) return data diff --git a/openwfs/plot_utilities.py b/openwfs/plot_utilities.py index 27d7516..1a9a6bd 100644 --- a/openwfs/plot_utilities.py +++ b/openwfs/plot_utilities.py @@ -14,11 +14,11 @@ def imshow(data, axis=None): e0 = scale_prefix(extent[0]) e1 = scale_prefix(extent[1]) if axis is None: - plt.imshow(data, extent=(0.0, e1.value, 0.0, e0.value), cmap='gray') + plt.imshow(data, extent=(0.0, e1.value, 0.0, e0.value), cmap="gray") plt.colorbar() axis = plt.gca() else: - axis.imshow(data, extent=(0.0, e1.value, 0.0, e0.value), cmap='gray') + axis.imshow(data, extent=(0.0, e1.value, 0.0, e0.value), cmap="gray") plt.ylabel(e0.unit.to_string()) plt.xlabel(e1.unit.to_string()) plt.show(block=False) @@ -28,7 +28,7 @@ def imshow(data, axis=None): def scale_prefix(value: u.Quantity) -> u.Quantity: """Scale a quantity to the most appropriate prefix unit.""" - if value.unit.physical_type == 'length': + if value.unit.physical_type == "length": if value < 100 * u.nm: return value.to(u.nm) if value < 100 * u.um: @@ -37,7 +37,7 @@ def scale_prefix(value: 
u.Quantity) -> u.Quantity: return value.to(u.mm) else: return value.to(u.m) - elif value.unit.physical_type == 'time': + elif value.unit.physical_type == "time": if value < 100 * u.ns: return value.to(u.ns) if value < 100 * u.us: diff --git a/openwfs/processors/__init__.py b/openwfs/processors/__init__.py index 1c2a16a..ca18915 100644 --- a/openwfs/processors/__init__.py +++ b/openwfs/processors/__init__.py @@ -1,2 +1,9 @@ from . import processors -from .processors import CropProcessor, SingleRoi, MultipleRoi, TransformProcessor, Roi, select_roi +from .processors import ( + CropProcessor, + SingleRoi, + MultipleRoi, + TransformProcessor, + Roi, + select_roi, +) diff --git a/openwfs/processors/processors.py b/openwfs/processors/processors.py index d6babae..d270a77 100644 --- a/openwfs/processors/processors.py +++ b/openwfs/processors/processors.py @@ -17,8 +17,9 @@ class Roi: radius, mask type, and parameters specific to the mask type. """ - def __init__(self, pos, radius=0.1, mask_type: str = 'disk', waist=None, - source_shape=None): + def __init__( + self, pos, radius=0.1, mask_type: str = "disk", waist=None, source_shape=None + ): """ Initialize the Roi object. @@ -36,10 +37,16 @@ def __init__(self, pos, radius=0.1, mask_type: str = 'disk', waist=None, """ if pos is None: pos = (source_shape[0] // 2, source_shape[1] // 2) - if round(pos[0] - radius) < 0 or round(pos[1] - radius) < 0 or source_shape is not None and ( - round(pos[0] + radius) >= source_shape[0] or - round(pos[1] + radius) >= source_shape[1]): - raise ValueError('ROI does not fit inside source image') + if ( + round(pos[0] - radius) < 0 + or round(pos[1] - radius) < 0 + or source_shape is not None + and ( + round(pos[0] + radius) >= source_shape[0] + or round(pos[1] + radius) >= source_shape[1] + ) + ): + raise ValueError("ROI does not fit inside source image") self._pos = pos self._radius = radius @@ -99,7 +106,7 @@ def mask_type(self) -> str: @waist.setter def waist(self, value: str): - if value not in ['disk', 'gaussian', 'square']: + if value not in ["disk", "gaussian", "square"]: raise ValueError("mask_type must be 'disk', 'gaussian', or 'square'") self._mask_type = value self._mask = None # need to re-compute mask @@ -124,10 +131,10 @@ def apply(self, image: np.ndarray, order: float = 1.0): # for circular masks, always use an odd number of pixels so that we have a clearly # defined center. # for square masks, instead use the actual size - if self.mask_type == 'disk': + if self.mask_type == "disk": d = round(self._radius) * 2 + 1 self._mask = disk(d, r) - elif self.mask_type == 'gaussian': + elif self.mask_type == "gaussian": d = round(self._radius) * 2 + 1 self._mask = gaussian(d, self._waist) else: # square @@ -138,13 +145,15 @@ def apply(self, image: np.ndarray, order: float = 1.0): image_start = np.array(self.pos) - int(0.5 * self._mask.shape[0] - 0.5) image_cropped = image[ - image_start[0]:image_start[0] + self._mask.shape[0], - image_start[1]:image_start[1] + self._mask.shape[1]] + image_start[0] : image_start[0] + self._mask.shape[0], + image_start[1] : image_start[1] + self._mask.shape[1], + ] if image_cropped.shape != self._mask.shape: raise ValueError( f"ROI is larger than the possible area. 
ROI shape: {self._mask.shape}, " - + f"Cropped image shape: {image_cropped.shape}") + + f"Cropped image shape: {image_cropped.shape}" + ) if order != 1.0: image_cropped = np.power(image_cropped, order) @@ -200,8 +209,15 @@ def pixel_size(self) -> None: class SingleRoi(MultipleRoi): - def __init__(self, source, pos=None, radius=0.1, mask_type: str = 'disk', waist=0.5, - multi_threaded: bool = True): + def __init__( + self, + source, + pos=None, + radius=0.1, + mask_type: str = "disk", + waist=0.5, + multi_threaded: bool = True, + ): """ Processor that averages a signal over a single region of interest (ROI). @@ -228,8 +244,14 @@ class CropProcessor(Processor): the data is padded with 'padding_value' """ - def __init__(self, source: Detector, shape: Optional[Sequence[int]] = None, - pos: Optional[Sequence[int]] = None, padding_value=0.0, multi_threaded: bool = False): + def __init__( + self, + source: Detector, + shape: Optional[Sequence[int]] = None, + pos: Optional[Sequence[int]] = None, + padding_value=0.0, + multi_threaded: bool = False, + ): """ Args: @@ -244,7 +266,11 @@ def __init__(self, source: Detector, shape: Optional[Sequence[int]] = None, """ super().__init__(source, multi_threaded=multi_threaded) self._data_shape = tuple(shape) if shape is not None else source.data_shape - self._pos = np.array(pos) if pos is not None else np.zeros((len(self.data_shape),), dtype=int) + self._pos = ( + np.array(pos) + if pos is not None + else np.zeros((len(self.data_shape),), dtype=int) + ) self._padding_value = padding_value @property @@ -272,16 +298,19 @@ def _fetch(self, image: np.ndarray) -> np.ndarray: # noqa Returns: the out array containing the cropped image. """ - src_start = np.maximum(self._pos, 0).astype('int32') - src_end = np.minimum(self._pos + self._data_shape, image.shape).astype('int32') - dst_start = np.maximum(-self._pos, 0).astype('int32') + src_start = np.maximum(self._pos, 0).astype("int32") + src_end = np.minimum(self._pos + self._data_shape, image.shape).astype("int32") + dst_start = np.maximum(-self._pos, 0).astype("int32") dst_end = dst_start + src_end - src_start src_select = tuple( - slice(start, end) for (start, end) in zip(src_start, src_end)) + slice(start, end) for (start, end) in zip(src_start, src_end) + ) src = image.__getitem__(src_select) if any(dst_start != 0) or any(dst_end != self._data_shape): dst = np.zeros(self._data_shape) + self._padding_value - dst_select = tuple(slice(start, end) for (start, end) in zip(dst_start, dst_end)) + dst_select = tuple( + slice(start, end) for (start, end) in zip(dst_start, dst_end) + ) dst.__setitem__(dst_select, src) else: dst = src @@ -293,10 +322,17 @@ def select_roi(source: Detector, mask_type: str): """ Opens a window that allows the user to select a region of interest. 
""" - if mask_type not in ['disk', 'gaussian', 'square']: + if mask_type not in ["disk", "gaussian", "square"]: raise ValueError("mask_type must be 'disk', 'gaussian', or 'square'") - image = cv2.normalize(source.read(), None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U) + image = cv2.normalize( + source.read(), + None, + alpha=0, + beta=255, + norm_type=cv2.NORM_MINMAX, + dtype=cv2.CV_8U, + ) title = "Select ROI and press c to continue or ESC to cancel" cv2.namedWindow(title) cv2.imshow(title, image) @@ -312,10 +348,18 @@ def mouse_callback(event, x, y, flags, _param): elif event == cv2.EVENT_MOUSEMOVE and cv2.EVENT_FLAG_LBUTTON & flags: roi_size = np.minimum(x - roi_start[0], y - roi_start[1]) rect_image = image.copy() - if mask_type == 'square': - cv2.rectangle(rect_image, roi_start, roi_start + roi_size, (0.0, 0.0, 255.0), 2) + if mask_type == "square": + cv2.rectangle( + rect_image, roi_start, roi_start + roi_size, (0.0, 0.0, 255.0), 2 + ) else: - cv2.circle(rect_image, roi_start + roi_size // 2, abs(roi_size) // 2, (0.0, 0.0, 255.0), 2) + cv2.circle( + rect_image, + roi_start + roi_size // 2, + abs(roi_size) // 2, + (0.0, 0.0, 255.0), + 2, + ) cv2.imshow(title, rect_image) cv2.setMouseCallback(title, mouse_callback) @@ -344,19 +388,24 @@ class TransformProcessor(Processor): should match the unit of the input data after applying the transform. """ - def __init__(self, source: Detector, - transform: Transform = None, - data_shape: Optional[Sequence[int]] = None, - pixel_size: Optional[Quantity] = None, - multi_threaded: bool = True): + def __init__( + self, + source: Detector, + transform: Transform = None, + data_shape: Optional[Sequence[int]] = None, + pixel_size: Optional[Quantity] = None, + multi_threaded: bool = True, + ): """ Args: transform: Transform object that describes the transformation from the source to the target image data_shape: Shape of the output. If omitted, the shape of the input data is used. multi_threaded: Whether to perform processing in a worker thread. - """ - if (data_shape is not None and len(data_shape) != 2) or len(source.data_shape) != 2: + """ + if (data_shape is not None and len(data_shape) != 2) or len( + source.data_shape + ) != 2: raise ValueError("TransformProcessor only supports 2-D data") if transform is None: transform = Transform() @@ -364,10 +413,14 @@ def __init__(self, source: Detector, # check if input and output pixel sizes are compatible dst_unit = transform.destination_unit(source.pixel_size.unit) if pixel_size is not None and not pixel_size.unit.is_equivalent(dst_unit): - raise ValueError("Pixel size unit does not match the unit of the transformed data") + raise ValueError( + "Pixel size unit does not match the unit of the transformed data" + ) if pixel_size is None and not source.pixel_size.unit.is_equivalent(dst_unit): - raise ValueError("The transform changes the unit of the coordinates." - " An output pixel_size must be provided.") + raise ValueError( + "The transform changes the unit of the coordinates." + " An output pixel_size must be provided." 
+ ) self.transform = transform super().__init__(source, multi_threaded=multi_threaded) @@ -390,5 +443,9 @@ def _fetch(self, source: np.ndarray) -> np.ndarray: # noqa Returns: ndarray that has been transformed TODO: Fix and add test, or remove """ - return project(source, transform=self.transform, out_shape=self.data_shape, - out_extent=self.extent) + return project( + source, + transform=self.transform, + out_shape=self.data_shape, + out_extent=self.extent, + ) diff --git a/openwfs/simulation/__init__.py b/openwfs/simulation/__init__.py index d4e122f..5f1807f 100644 --- a/openwfs/simulation/__init__.py +++ b/openwfs/simulation/__init__.py @@ -4,6 +4,13 @@ from . import transmission from .microscope import Microscope -from .mockdevices import XYStage, StaticSource, Camera, ADCProcessor, Shutter, NoiseSource +from .mockdevices import ( + XYStage, + StaticSource, + Camera, + ADCProcessor, + Shutter, + NoiseSource, +) from .slm import SLM, PhaseToField from .transmission import SimulatedWFS diff --git a/openwfs/simulation/microscope.py b/openwfs/simulation/microscope.py index c1cdfe0..93cae59 100644 --- a/openwfs/simulation/microscope.py +++ b/openwfs/simulation/microscope.py @@ -11,7 +11,14 @@ from ..plot_utilities import imshow # noqa - for debugging from ..processors import TransformProcessor from ..simulation.mockdevices import XYStage, Camera, StaticSource -from ..utilities import project, place, Transform, get_pixel_size, patterns, CoordinateType +from ..utilities import ( + project, + place, + Transform, + get_pixel_size, + patterns, + CoordinateType, +) class Microscope(Processor): @@ -31,15 +38,22 @@ class Microscope(Processor): that has the same total intensity as the source image. """ - def __init__(self, source: Union[Detector, np.ndarray], *, data_shape=None, - numerical_aperture: float = 1.0, - wavelength: Quantity[u.nm], - magnification: float = 1.0, xy_stage=None, z_stage=None, - incident_field: Union[Detector, ArrayLike, None] = None, - incident_transform: Optional[Transform] = None, - aberrations: Union[Detector, np.ndarray, None] = None, - aberration_transform: Optional[Transform] = None, - multi_threaded: bool = True): + def __init__( + self, + source: Union[Detector, np.ndarray], + *, + data_shape=None, + numerical_aperture: float = 1.0, + wavelength: Quantity[u.nm], + magnification: float = 1.0, + xy_stage=None, + z_stage=None, + incident_field: Union[Detector, ArrayLike, None] = None, + incident_transform: Optional[Transform] = None, + aberrations: Union[Detector, np.ndarray, None] = None, + aberration_transform: Optional[Transform] = None, + multi_threaded: bool = True + ): """ Args: source: 2-D image (must have `pixel_size` metadata), or @@ -93,7 +107,9 @@ def __init__(self, source: Union[Detector, np.ndarray], *, data_shape=None, if get_pixel_size(aberrations) is None: aberrations = StaticSource(aberrations) - super().__init__(source, aberrations, incident_field, multi_threaded=multi_threaded) + super().__init__( + source, aberrations, incident_field, multi_threaded=multi_threaded + ) self._magnification = magnification self._data_shape = data_shape if data_shape is not None else source.data_shape self.numerical_aperture = numerical_aperture @@ -105,8 +121,12 @@ def __init__(self, source: Union[Detector, np.ndarray], *, data_shape=None, self.z_stage = z_stage # or MockStage() self._psf = None - def _fetch(self, source: np.ndarray, aberrations: np.ndarray, # noqa - incident_field: np.ndarray) -> np.ndarray: + def _fetch( + self, + source: np.ndarray, + aberrations: 
np.ndarray, # noqa + incident_field: np.ndarray, + ) -> np.ndarray: """ Updates the image on the camera sensor @@ -133,7 +153,9 @@ def _fetch(self, source: np.ndarray, aberrations: np.ndarray, # noqa source_pixel_size = get_pixel_size(source) target_pixel_size = self.pixel_size / self.magnification if np.any(source_pixel_size > target_pixel_size): - warnings.warn("The resolution of the specimen image is worse than that of the output.") + warnings.warn( + "The resolution of the specimen image is worse than that of the output." + ) # Note: there seems to be a bug (feature?) in `fftconvolve` that shifts the image by one pixel # when the 'same' option is used. To compensate for this feature, @@ -166,23 +188,40 @@ def _fetch(self, source: np.ndarray, aberrations: np.ndarray, # noqa # Compute the field in the pupil plane # The aberrations and the SLM phase pattern are both mapped to the pupil plane coordinates - pupil_field = patterns.disk(pupil_shape, radius=self.numerical_aperture, extent=pupil_extent) - pupil_area = np.sum(pupil_field) # TODO (efficiency): compute area directly from radius + pupil_field = patterns.disk( + pupil_shape, radius=self.numerical_aperture, extent=pupil_extent + ) + pupil_area = np.sum( + pupil_field + ) # TODO (efficiency): compute area directly from radius # Project aberrations if aberrations is not None: # use default of 2.0 * NA for the extent of the aberration map if no pixel size is provided - aberration_extent = (2.0 * self.numerical_aperture,) * 2 if get_pixel_size(aberrations) is None else None - pupil_field = pupil_field * np.exp(1.0j * project(aberrations, - source_extent=aberration_extent, - out_extent=pupil_extent, - out_shape=pupil_shape, - transform=self.aberration_transform)) + aberration_extent = ( + (2.0 * self.numerical_aperture,) * 2 + if get_pixel_size(aberrations) is None + else None + ) + pupil_field = pupil_field * np.exp( + 1.0j + * project( + aberrations, + source_extent=aberration_extent, + out_extent=pupil_extent, + out_shape=pupil_shape, + transform=self.aberration_transform, + ) + ) # Project SLM fields if incident_field is not None: - pupil_field = pupil_field * project(incident_field, out_extent=pupil_extent, out_shape=pupil_shape, - transform=self.slm_transform) + pupil_field = pupil_field * project( + incident_field, + out_extent=pupil_extent, + out_shape=pupil_shape, + transform=self.slm_transform, + ) # Compute the point spread function # This is done by Fourier transforming the pupil field and taking the absolute value squared @@ -193,7 +232,7 @@ def _fetch(self, source: np.ndarray, aberrations: np.ndarray, # noqa psf = np.fft.ifftshift(psf) * (psf.size / pupil_area) self._psf = psf # store psf for later inspection - return fftconvolve(source, psf, 'same') + return fftconvolve(source, psf, "same") @property def magnification(self) -> float: @@ -225,10 +264,14 @@ def data_shape(self): """Returns the shape of the image in the image plane""" return self._data_shape - def get_camera(self, *, transform: Optional[Transform] = None, - data_shape: Optional[tuple[int, int]] = None, - pixel_size: Optional[CoordinateType] = None, - **kwargs) -> Detector: + def get_camera( + self, + *, + transform: Optional[Transform] = None, + data_shape: Optional[tuple[int, int]] = None, + pixel_size: Optional[CoordinateType] = None, + **kwargs + ) -> Detector: """ Returns a simulated camera that observes the microscope image. 
@@ -247,6 +290,8 @@ def get_camera(self, *, transform: Optional[Transform] = None, if transform is None and data_shape is None and pixel_size is None: src = self else: - src = TransformProcessor(self, data_shape=data_shape, pixel_size=pixel_size, transform=transform) + src = TransformProcessor( + self, data_shape=data_shape, pixel_size=pixel_size, transform=transform + ) return Camera(src, **kwargs) diff --git a/openwfs/simulation/mockdevices.py b/openwfs/simulation/mockdevices.py index 1f002a7..4dedeb1 100644 --- a/openwfs/simulation/mockdevices.py +++ b/openwfs/simulation/mockdevices.py @@ -15,8 +15,15 @@ class StaticSource(Detector): Detector that returns pre-set data. Also simulates latency and measurement duration. """ - def __init__(self, data: np.ndarray, pixel_size: Optional[ExtentType] = None, extent: Optional[ExtentType] = None, - latency: Quantity[u.ms] = 0 * u.ms, duration: Quantity[u.ms] = 0 * u.ms, multi_threaded: bool = None): + def __init__( + self, + data: np.ndarray, + pixel_size: Optional[ExtentType] = None, + extent: Optional[ExtentType] = None, + latency: Quantity[u.ms] = 0 * u.ms, + duration: Quantity[u.ms] = 0 * u.ms, + multi_threaded: bool = None, + ): """ Initializes the MockSource TODO: factor out the latency and duration into a separate class? @@ -35,15 +42,24 @@ def __init__(self, data: np.ndarray, pixel_size: Optional[ExtentType] = None, ex else: pixel_size = get_pixel_size(data) - if pixel_size is not None and (np.isscalar(pixel_size) or pixel_size.size == 1) and data.ndim > 1: + if ( + pixel_size is not None + and (np.isscalar(pixel_size) or pixel_size.size == 1) + and data.ndim > 1 + ): pixel_size = pixel_size.repeat(data.ndim) if multi_threaded is None: multi_threaded = latency > 0 * u.ms or duration > 0 * u.ms self._data = data - super().__init__(data_shape=data.shape, pixel_size=pixel_size, latency=latency, duration=duration, - multi_threaded=multi_threaded) + super().__init__( + data_shape=data.shape, + pixel_size=pixel_size, + latency=latency, + duration=duration, + multi_threaded=multi_threaded, + ) def _fetch(self) -> np.ndarray: # noqa total_time_s = self.latency.to_value(u.s) + self.duration.to_value(u.s) @@ -72,22 +88,34 @@ def data(self, value): class NoiseSource(Detector): - def __init__(self, noise_type: str, *, data_shape: tuple[int, ...], pixel_size: Quantity, multi_threaded=True, - generator=None, - **kwargs): + def __init__( + self, + noise_type: str, + *, + data_shape: tuple[int, ...], + pixel_size: Quantity, + multi_threaded=True, + generator=None, + **kwargs, + ): self._noise_type = noise_type self._noise_arguments = kwargs self._rng = generator if generator is not None else np.random.default_rng() - super().__init__(data_shape=data_shape, pixel_size=pixel_size, latency=0 * u.ms, duration=0 * u.ms, - multi_threaded=multi_threaded) + super().__init__( + data_shape=data_shape, + pixel_size=pixel_size, + latency=0 * u.ms, + duration=0 * u.ms, + multi_threaded=multi_threaded, + ) def _fetch(self) -> np.ndarray: # noqa - if self._noise_type == 'uniform': + if self._noise_type == "uniform": return self._rng.uniform(**self._noise_arguments, size=self.data_shape) - elif self._noise_type == 'gaussian': + elif self._noise_type == "gaussian": return self._rng.normal(**self._noise_arguments, size=self.data_shape) else: - raise ValueError(f'Unknown noise type: {self._noise_type}') + raise ValueError(f"Unknown noise type: {self._noise_type}") @Detector.data_shape.setter def data_shape(self, value): @@ -100,9 +128,16 @@ class ADCProcessor(Processor): At 
the moment, only positive input and output values are supported. """ - def __init__(self, source: Detector, analog_max: float = 0.0, digital_max: int = 0xFFFF, - shot_noise: bool = False, gaussian_noise_std: float = 0.0, multi_threaded: bool = True, - generator=None): + def __init__( + self, + source: Detector, + analog_max: float = 0.0, + digital_max: int = 0xFFFF, + shot_noise: bool = False, + gaussian_noise_std: float = 0.0, + multi_threaded: bool = True, + generator=None, + ): """ Initializes the ADCProcessor class, which mimics an analog-digital converter. @@ -140,7 +175,9 @@ def _fetch(self, data) -> np.ndarray: # noqa if self.analog_max == 0.0: # auto scaling max_value = np.max(data) if max_value > 0.0: - data = data * (self.digital_max / max_value) # auto-scale to maximum value + data = data * ( + self.digital_max / max_value + ) # auto-scale to maximum value else: data = data * (self.digital_max / self.analog_max) @@ -148,9 +185,11 @@ def _fetch(self, data) -> np.ndarray: # noqa data = self._rng.poisson(data) if self._gaussian_noise_std > 0.0: - data = data + self._rng.normal(scale=self._gaussian_noise_std, size=data.shape) + data = data + self._rng.normal( + scale=self._gaussian_noise_std, size=data.shape + ) - return np.clip(np.rint(data), 0, self.digital_max).astype('uint16') + return np.clip(np.rint(data), 0, self.digital_max).astype("uint16") @property def analog_max(self) -> Optional[float]: @@ -165,7 +204,7 @@ def analog_max(self) -> Optional[float]: @analog_max.setter def analog_max(self, value): if value < 0.0: - raise ValueError('analog_max cannot be negative') + raise ValueError("analog_max cannot be negative") self._analog_max = value @property @@ -179,7 +218,7 @@ def digital_max(self) -> int: @digital_max.setter def digital_max(self, value): if value < 0 or value > 0xFFFF: - raise ValueError('digital_max must be between 0 and 0xFFFF') + raise ValueError("digital_max must be between 0 and 0xFFFF") self._digital_max = int(value) @property @@ -209,8 +248,13 @@ class Camera(ADCProcessor): Conversion to uint16 is implemented in the ADCProcessor base class. """ - def __init__(self, source: Detector, shape: Optional[Sequence[int]] = None, - pos: Optional[Sequence[int]] = None, **kwargs): + def __init__( + self, + source: Detector, + shape: Optional[Sequence[int]] = None, + pos: Optional[Sequence[int]] = None, + **kwargs, + ): """ Args: source (Detector): The source detector to be wrapped. diff --git a/openwfs/simulation/slm.py b/openwfs/simulation/slm.py index a5905a5..eaa04b2 100644 --- a/openwfs/simulation/slm.py +++ b/openwfs/simulation/slm.py @@ -17,8 +17,12 @@ class PhaseToField(Processor): Computes `amplitude * (exp(1j * phase) + non_modulated_field_fraction)` """ - def __init__(self, slm_phases: Detector, field_amplitude: ArrayLike = 1.0, - non_modulated_field_fraction: float = 0.0): + def __init__( + self, + slm_phases: Detector, + field_amplitude: ArrayLike = 1.0, + non_modulated_field_fraction: float = 0.0, + ): """ Args: slm_phases: The `Detector` that returns the phases of the slm pixels. @@ -34,7 +38,9 @@ def _fetch(self, slm_phases: np.ndarray) -> np.ndarray: # noqa Updates the complex field output of the SLM. The output field is the sum of the modulated field and the non-modulated field. 
""" - return self.modulated_field_amplitude * (np.exp(1j * slm_phases) + self.non_modulated_field) + return self.modulated_field_amplitude * ( + np.exp(1j * slm_phases) + self.non_modulated_field + ) class _SLMTiming(Detector): @@ -45,15 +51,22 @@ class _SLMTiming(Detector): the refresh rate, or the conversion of gray values to phases. """ - def __init__(self, - shape: tuple[int, ...], - update_latency: Quantity[u.ms] = 0.0 * u.ms, - update_duration: Quantity[u.ms] = 0.0 * u.ms): + def __init__( + self, + shape: tuple[int, ...], + update_latency: Quantity[u.ms] = 0.0 * u.ms, + update_duration: Quantity[u.ms] = 0.0 * u.ms, + ): if len(shape) != 2: raise ValueError("Shape of the SLM should be 2-dimensional.") - super().__init__(data_shape=shape, pixel_size=Quantity(2.0 / np.min(shape)), latency=0 * u.ms, - duration=0 * u.ms, multi_threaded=False) + super().__init__( + data_shape=shape, + pixel_size=Quantity(2.0 / np.min(shape)), + latency=0 * u.ms, + duration=0 * u.ms, + multi_threaded=False, + ) self.update_latency = update_latency self.update_duration = update_duration @@ -139,20 +152,29 @@ class SLM(PhaseSLM, Actuator): A mock version of a phase-only spatial light modulator. Some properties are available to simulate physical phenomena such as imperfect phase response, and front reflections (which cause non-modulated light). """ - __slots__ = ('_hardware_fields', '_hardware_phases', '_hardware_timing', '_back_buffer', - 'refresh_rate', '_first_update_ns', '_lookup_table') - - def __init__(self, - shape: tuple[int, ...], - latency: Quantity[u.ms] = 0.0 * u.ms, - duration: Quantity[u.ms] = 0.0 * u.ms, - update_latency: Quantity[u.ms] = 0.0 * u.ms, - update_duration: Quantity[u.ms] = 0.0 * u.ms, - refresh_rate: Quantity[u.Hz] = 0 * u.Hz, - field_amplitude: Union[np.ndarray, float, None] = 1.0, - non_modulated_field_fraction: float = 0.0, - phase_response: Optional[np.ndarray] = None, - ): + + __slots__ = ( + "_hardware_fields", + "_hardware_phases", + "_hardware_timing", + "_back_buffer", + "refresh_rate", + "_first_update_ns", + "_lookup_table", + ) + + def __init__( + self, + shape: tuple[int, ...], + latency: Quantity[u.ms] = 0.0 * u.ms, + duration: Quantity[u.ms] = 0.0 * u.ms, + update_latency: Quantity[u.ms] = 0.0 * u.ms, + update_duration: Quantity[u.ms] = 0.0 * u.ms, + refresh_rate: Quantity[u.Hz] = 0 * u.Hz, + field_amplitude: Union[np.ndarray, float, None] = 1.0, + non_modulated_field_fraction: float = 0.0, + phase_response: Optional[np.ndarray] = None, + ): """ Args: @@ -169,16 +191,20 @@ def __init__(self, Choose a value different from `duration` to simulate incorrect timing. refresh_rate: Simulated refresh rate. Affects the timing of the `update` method, since this will wait until the next vertical retrace. Keep at 0 to disable this feature. 
-            """
+        """
         super().__init__(latency=latency, duration=duration)
         self.refresh_rate = refresh_rate
         # Simulates transferring frames to the SLM
         self._hardware_timing = _SLMTiming(shape, update_latency, update_duration)
-        self._hardware_phases = _SLMPhaseResponse(self._hardware_timing,
-                                                  phase_response)  # Simulates reading the phase from the SLM
-        self._hardware_fields = PhaseToField(self._hardware_phases, field_amplitude,
-                                             non_modulated_field_fraction)  # Simulates reading the field from the SLM
-        self._lookup_table = None  # index = input phase (scaled to -> [0, 255]), value = grey value
+        self._hardware_phases = _SLMPhaseResponse(
+            self._hardware_timing, phase_response
+        )  # Simulates reading the phase from the SLM
+        self._hardware_fields = PhaseToField(
+            self._hardware_phases, field_amplitude, non_modulated_field_fraction
+        )  # Simulates reading the field from the SLM
+        self._lookup_table = (
+            None  # index = input phase (scaled to -> [0, 255]), value = grey value
+        )
         self._first_update_ns = time.time_ns()
         self._back_buffer = np.zeros(shape, dtype=np.float32)
@@ -188,8 +214,12 @@ def update(self):
         self._start()  # wait for detectors to finish
         if self.refresh_rate > 0:
             # wait for the vertical retrace
-            time_in_frames = unitless((time.time_ns() - self._first_update_ns) * u.ns * self.refresh_rate)
-            time_to_next_frame = (np.ceil(time_in_frames) - time_in_frames) / self.refresh_rate
+            time_in_frames = unitless(
+                (time.time_ns() - self._first_update_ns) * u.ns * self.refresh_rate
+            )
+            time_to_next_frame = (
+                np.ceil(time_in_frames) - time_in_frames
+            ) / self.refresh_rate
             time.sleep(time_to_next_frame.to_value(u.s))
             # update the start time (this is also done in the actual SLM)
             self._start()
@@ -205,7 +235,9 @@ def update(self):
         if self._lookup_table is None:
             grey_values = (256 * tx).astype(np.uint8)
         else:
-            lookup_index = (self._lookup_table.shape[0] * tx).astype(np.uint8)  # index into lookup table
+            lookup_index = (self._lookup_table.shape[0] * tx).astype(
+                np.uint8
+            )  # index into lookup table
             grey_values = self._lookup_table[lookup_index]
         self._hardware_timing.send(grey_values)
 
@@ -242,8 +274,12 @@ def set_phases(self, values: ArrayLike, update=True):
         # no docstring, use documentation from base class
         # Copy the phase image to the back buffer, scaling it as necessary
-        project(np.atleast_2d(values).astype('float32'), out=self._back_buffer, source_extent=(2.0, 2.0),
-                out_extent=(2.0, 2.0))
+        project(
+            np.atleast_2d(values).astype("float32"),
+            out=self._back_buffer,
+            source_extent=(2.0, 2.0),
+            out_extent=(2.0, 2.0),
+        )
         if update:
             self.update()
diff --git a/openwfs/simulation/transmission.py b/openwfs/simulation/transmission.py
index 88e5a85..6691fa2 100644
--- a/openwfs/simulation/transmission.py
+++ b/openwfs/simulation/transmission.py
@@ -18,8 +18,15 @@ class SimulatedWFS(Processor):
     For a more advanced (but slower) simulation, use `Microscope`
     """
 
-    def __init__(self, *, t: Optional[np.ndarray] = None, aberrations: Optional[np.ndarray] = None, slm=None,
-                 multi_threaded=True, beam_amplitude: ScalarType = 1.0):
+    def __init__(
+        self,
+        *,
+        t: Optional[np.ndarray] = None,
+        aberrations: Optional[np.ndarray] = None,
+        slm=None,
+        multi_threaded=True,
+        beam_amplitude: ScalarType = 1.0
+    ):
         """
         Initializes the optical system with specified aberrations and optionally a Gaussian beam profile.
@@ -43,7 +50,12 @@ def __init__(self, *, t: Optional[np.ndarray] = None, aberrations: Optional[np.n """ # transmission matrix (normalized so that the maximum transmission is 1) - self._t = t if t is not None else np.exp(1.0j * aberrations) / (aberrations.shape[0] * aberrations.shape[1]) + self._t = ( + t + if t is not None + else np.exp(1.0j * aberrations) + / (aberrations.shape[0] * aberrations.shape[1]) + ) self.slm = slm if slm is not None else SLM(self._t.shape[0:2]) super().__init__(self.slm.field, multi_threaded=multi_threaded) diff --git a/openwfs/utilities/__init__.py b/openwfs/utilities/__init__.py index 3765cba..3136f24 100644 --- a/openwfs/utilities/__init__.py +++ b/openwfs/utilities/__init__.py @@ -1,5 +1,15 @@ from . import patterns from . import utilities from .patterns import coordinate_range, disk, gaussian, tilt -from .utilities import ExtentType, CoordinateType, unitless, get_pixel_size, \ - set_pixel_size, Transform, project, place, set_extent, get_extent +from .utilities import ( + ExtentType, + CoordinateType, + unitless, + get_pixel_size, + set_pixel_size, + Transform, + project, + place, + set_extent, + get_extent, +) diff --git a/openwfs/utilities/patterns.py b/openwfs/utilities/patterns.py index 1dc9d47..818c352 100644 --- a/openwfs/utilities/patterns.py +++ b/openwfs/utilities/patterns.py @@ -42,8 +42,9 @@ """ -def coordinate_range(shape: ShapeType, extent: ExtentType, offset: Optional[CoordinateType] = None) -> (Quantity, - Quantity): +def coordinate_range( + shape: ShapeType, extent: ExtentType, offset: Optional[CoordinateType] = None +) -> (Quantity, Quantity): """ Returns coordinate vectors for the two coordinates (y and x). @@ -72,8 +73,10 @@ def c_range(res, ex, cx): dx = ex / res return np.arange(res) * dx + (0.5 * dx - 0.5 * ex + cx) - return (c_range(shape[0], extent[0], offset[0]).reshape((-1, 1)), - c_range(shape[1], extent[1], offset[1]).reshape((1, -1))) + return ( + c_range(shape[0], extent[0], offset[0]).reshape((-1, 1)), + c_range(shape[1], extent[1], offset[1]).reshape((1, -1)), + ) def r2_range(shape: ShapeType, extent: ExtentType): @@ -81,10 +84,15 @@ def r2_range(shape: ShapeType, extent: ExtentType): Equivalent to computing cx^2 + cy^2 """ c0, c1 = coordinate_range(shape, extent) - return c0 ** 2 + c1 ** 2 + return c0**2 + c1**2 -def tilt(shape: ShapeType, g: ExtentType, extent: ExtentType = (2.0, 2.0), phase_offset: float = 0.0): +def tilt( + shape: ShapeType, + g: ExtentType, + extent: ExtentType = (2.0, 2.0), + phase_offset: float = 0.0, +): """Constructs a linear gradient pattern φ=2 g·r Args: @@ -115,11 +123,17 @@ def lens(shape: ShapeType, f: ScalarType, wavelength: ScalarType, extent: Extent extent(ExtentType): physical extent of the SLM, same units as `f` and `wavelength` """ r_sqr = r2_range(shape, extent) - return unitless((f - np.sqrt(f ** 2 + r_sqr)) * (2 * np.pi / wavelength)) + return unitless((f - np.sqrt(f**2 + r_sqr)) * (2 * np.pi / wavelength)) -def propagation(shape: ShapeType, distance: ScalarType, numerical_aperture: ScalarType, - refractive_index: ScalarType, wavelength: ScalarType, extent: ExtentType = (2.0, 2.0)): +def propagation( + shape: ShapeType, + distance: ScalarType, + numerical_aperture: ScalarType, + refractive_index: ScalarType, + wavelength: ScalarType, + extent: ExtentType = (2.0, 2.0), +): """Computes a wavefront that corresponds to digitally propagating the field in the object plane. 
k_z = sqrt(n² k_0²-k_x²-k_y²) @@ -139,7 +153,9 @@ def propagation(shape: ShapeType, distance: ScalarType, numerical_aperture: Scal # convert pupil coordinates to absolute k_x, k_y coordinates k_0 = 2.0 * np.pi / wavelength extent_k = Quantity(extent) * numerical_aperture * k_0 - k_z = np.sqrt(np.maximum((refractive_index * k_0) ** 2 - r2_range(shape, extent_k), 0.0)) + k_z = np.sqrt( + np.maximum((refractive_index * k_0) ** 2 - r2_range(shape, extent_k), 0.0) + ) return unitless(k_z * distance) @@ -153,11 +169,15 @@ def disk(shape: ShapeType, radius: ScalarType = 1.0, extent: ExtentType = (2.0, radius (ScalarType): radius of the disk, should have the same unit as `extent`. extent: see module documentation """ - return 1.0 * (r2_range(shape, extent) < radius ** 2) + return 1.0 * (r2_range(shape, extent) < radius**2) -def gaussian(shape: ShapeType, waist: ScalarType, - truncation_radius: ScalarType = None, extent: ExtentType = (2.0, 2.0)): +def gaussian( + shape: ShapeType, + waist: ScalarType, + truncation_radius: ScalarType = None, + extent: ExtentType = (2.0, 2.0), +): """Constructs an image of a centered Gaussian `waist`, `extent` and the optional `truncation_radius` should all have the same unit. @@ -172,7 +192,7 @@ def gaussian(shape: ShapeType, waist: ScalarType, """ r_sqr = r2_range(shape, extent) - w2inv = -1.0 / waist ** 2 + w2inv = -1.0 / waist**2 gauss = np.exp(unitless(r_sqr * w2inv)) if truncation_radius is not None: gauss = gauss * disk(shape, truncation_radius, extent=extent) diff --git a/openwfs/utilities/utilities.py b/openwfs/utilities/utilities.py index 1b7b5c9..3572e46 100644 --- a/openwfs/utilities/utilities.py +++ b/openwfs/utilities/utilities.py @@ -89,16 +89,25 @@ class Transform: """ - def __init__(self, transform: Optional[TransformType] = None, - source_origin: Optional[CoordinateType] = None, - destination_origin: Optional[CoordinateType] = None): + def __init__( + self, + transform: Optional[TransformType] = None, + source_origin: Optional[CoordinateType] = None, + destination_origin: Optional[CoordinateType] = None, + ): self.transform = Quantity(transform if transform is not None else np.eye(2)) - self.source_origin = Quantity(source_origin) if source_origin is not None else None - self.destination_origin = Quantity(destination_origin) if destination_origin is not None else None + self.source_origin = ( + Quantity(source_origin) if source_origin is not None else None + ) + self.destination_origin = ( + Quantity(destination_origin) if destination_origin is not None else None + ) if source_origin is not None: - self.destination_unit(self.source_origin.unit) # check if the units are consistent + self.destination_unit( + self.source_origin.unit + ) # check if the units are consistent def destination_unit(self, src_unit: u.Unit) -> u.Unit: """Computes the unit of the output of the transformation, given the unit of the input. @@ -107,20 +116,28 @@ def destination_unit(self, src_unit: u.Unit) -> u.Unit: ValueError: If src_unit does not match the unit of the source_origin (if specified) or if dst_unit does not match the unit of the destination_origin (if specified). 
""" - if self.source_origin is not None and not self.source_origin.unit.is_equivalent(src_unit): + if ( + self.source_origin is not None + and not self.source_origin.unit.is_equivalent(src_unit) + ): raise ValueError("src_unit must match the units of source_origin.") dst_unit = (self.transform[0, 0] * src_unit).unit - if self.destination_origin is not None and not self.destination_origin.unit.is_equivalent(dst_unit): + if ( + self.destination_origin is not None + and not self.destination_origin.unit.is_equivalent(dst_unit) + ): raise ValueError("dst_unit must match the units of destination_origin.") return dst_unit - def cv2_matrix(self, - source_shape: Sequence[int], - source_pixel_size: CoordinateType, - destination_shape: Sequence[int], - destination_pixel_size: CoordinateType) -> np.ndarray: + def cv2_matrix( + self, + source_shape: Sequence[int], + source_pixel_size: CoordinateType, + destination_shape: Sequence[int], + destination_pixel_size: CoordinateType, + ) -> np.ndarray: """Returns the transformation matrix in the format used by cv2.warpAffine.""" # correct the origin. OpenCV uses the _center_ of the top-left corner as the origin @@ -133,30 +150,42 @@ def cv2_matrix(self, if self.source_origin is not None: source_origin += self.source_origin - destination_origin = 0.5 * (np.array(destination_shape) - 1.0) * destination_pixel_size + destination_origin = ( + 0.5 * (np.array(destination_shape) - 1.0) * destination_pixel_size + ) if self.destination_origin is not None: destination_origin += self.destination_origin - centered_transform = Transform(transform=self.transform, - source_origin=source_origin, - destination_origin=destination_origin) + centered_transform = Transform( + transform=self.transform, + source_origin=source_origin, + destination_origin=destination_origin, + ) # then convert the transform to a matrix, using the specified pixel sizes - transform_matrix = centered_transform.to_matrix(source_pixel_size=source_pixel_size, - destination_pixel_size=destination_pixel_size) + transform_matrix = centered_transform.to_matrix( + source_pixel_size=source_pixel_size, + destination_pixel_size=destination_pixel_size, + ) # finally, convert the matrix to the format used by cv2.warpAffine by swapping x and y columns and rows transform_matrix = transform_matrix[[1, 0], :] transform_matrix = transform_matrix[:, [1, 0, 2]] return transform_matrix - def to_matrix(self, source_pixel_size: CoordinateType, destination_pixel_size: CoordinateType) -> np.ndarray: + def to_matrix( + self, source_pixel_size: CoordinateType, destination_pixel_size: CoordinateType + ) -> np.ndarray: matrix = np.zeros((2, 3)) - matrix[0:2, 0:2] = unitless(self.transform * source_pixel_size / destination_pixel_size) + matrix[0:2, 0:2] = unitless( + self.transform * source_pixel_size / destination_pixel_size + ) if self.destination_origin is not None: matrix[0:2, 2] = unitless(self.destination_origin / destination_pixel_size) if self.source_origin is not None: - matrix[0:2, 2] -= unitless((self.transform @ self.source_origin) / destination_pixel_size) + matrix[0:2, 2] -= unitless( + (self.transform @ self.source_origin) / destination_pixel_size + ) return matrix def opencl_matrix(self) -> np.ndarray: @@ -167,9 +196,15 @@ def opencl_matrix(self) -> np.ndarray: # to construct the homogeneous transformation matrix # convert to opencl format: swap x and y columns (note: the rows were # already swapped in the construction of t2), and flip the sign of the y-axis. 
- transform = np.eye(3, 4, dtype='float32', order='C') - transform[0, 0:3] = matrix[1, [1, 0, 2],] - transform[1, 0:3] = -matrix[0, [1, 0, 2],] + transform = np.eye(3, 4, dtype="float32", order="C") + transform[0, 0:3] = matrix[ + 1, + [1, 0, 2], + ] + transform[1, 0:3] = -matrix[ + 0, + [1, 0, 2], + ] return transform @staticmethod @@ -188,7 +223,8 @@ def __matmul__(self, other): def apply(self, vector: CoordinateType) -> CoordinateType: """Applies the transformation to a column vector. - If `vector` is a 2-D array, applies the transformation to each column of `vector` individually.""" + If `vector` is a 2-D array, applies the transformation to each column of `vector` individually. + """ if self.source_origin is not None: vector = vector - self.source_origin vector = self.transform @ vector @@ -198,7 +234,8 @@ def apply(self, vector: CoordinateType) -> CoordinateType: def inverse(self): """Compute the inverse transformation, - such that the composition of the transformation and its inverse is the identity.""" + such that the composition of the transformation and its inverse is the identity. + """ # invert the transform matrix if self.transform is not None: @@ -207,9 +244,13 @@ def inverse(self): transform = None # swap source and destination origins - return Transform(transform, source_origin=self.destination_origin, destination_origin=self.source_origin) + return Transform( + transform, + source_origin=self.destination_origin, + destination_origin=self.source_origin, + ) - def compose(self, other: 'Transform'): + def compose(self, other: "Transform"): """Compose two transformations. Args: @@ -220,7 +261,11 @@ def compose(self, other: 'Transform'): """ transform = self.transform @ other.transform source_origin = other.source_origin - destination_origin = self.apply(other.destination_origin) if other.destination_origin is not None else None + destination_origin = ( + self.apply(other.destination_origin) + if other.destination_origin is not None + else None + ) return Transform(transform, source_origin, destination_origin) def _standard_input(self) -> Quantity: @@ -232,8 +277,13 @@ def identity(cls): return Transform() -def place(out_shape: tuple[int, ...], out_pixel_size: Quantity, source: np.ndarray, offset: Optional[Quantity] = None, - out: Optional[np.ndarray] = None): +def place( + out_shape: tuple[int, ...], + out_pixel_size: Quantity, + source: np.ndarray, + offset: Optional[Quantity] = None, + out: Optional[np.ndarray] = None, +): """Takes a source array and places it in an otherwise empty array of specified shape and pixel size. The source array must have a pixel_size property (see set_pixel_size). 
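
As an aside to the `Transform` hunks reformatted above: the API touched here (the constructor plus `apply` and `inverse`) can be exercised as in the short sketch below. This mirrors the usage in `tests/test_utilities.py` later in this patch; the absolute import path `openwfs.utilities` is an assumption (the tests themselves use a relative import), and the numbers are illustrative only.

    # sketch of the Transform helpers shown in this diff (assumed import path)
    from openwfs.utilities import Transform

    # an affine coordinate map: subtract source_origin, multiply by the 2x2
    # matrix, then add destination_origin (same parameters as test_inverse)
    t = Transform(
        transform=((0.1, 0.2), (-0.25, 0.33)),
        source_origin=(0.12, 0.15),
        destination_origin=(0.23, 0.33),
    )
    mapped = t.apply((0.3, 0.4))          # forward mapping of one (y, x) point
    restored = t.inverse().apply(mapped)  # maps back to (0.3, 0.4), up to rounding
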
@@ -251,16 +301,20 @@ def place(out_shape: tuple[int, ...], out_pixel_size: Quantity, source: np.ndarr """ out_extent = out_pixel_size * np.array(out_shape) transform = Transform(destination_origin=offset) - return project(source, out_extent=out_extent, out_shape=out_shape, transform=transform, out=out) + return project( + source, out_extent=out_extent, out_shape=out_shape, transform=transform, out=out + ) def project( - source: np.ndarray, *, - source_extent: Optional[ExtentType] = None, - transform: Optional[Transform] = None, - out: Optional[np.ndarray] = None, - out_extent: Optional[ExtentType] = None, - out_shape: Optional[tuple[int, ...]] = None) -> np.ndarray: + source: np.ndarray, + *, + source_extent: Optional[ExtentType] = None, + transform: Optional[Transform] = None, + out: Optional[np.ndarray] = None, + out_extent: Optional[ExtentType] = None, + out_shape: Optional[tuple[int, ...]] = None +) -> np.ndarray: """Projects the input image onto an array with specified shape and resolution. The input image is scaled so that the pixel sizes match those of the output, @@ -281,7 +335,9 @@ def project( transform = transform if transform is not None else Transform() if out is not None: if out_shape is not None and out_shape != out.shape: - raise ValueError("out_shape and out.shape must match. Note that out_shape may be omitted") + raise ValueError( + "out_shape and out.shape must match. Note that out_shape may be omitted" + ) if out.dtype != source.dtype: raise ValueError("out and source must have the same dtype") out_shape = out.shape @@ -289,7 +345,9 @@ def project( if out_shape is None: raise ValueError("Either out_shape or out must be specified") if out_extent is None: - raise ValueError("Either out_extent or the pixel_size metadata of out must be specified") + raise ValueError( + "Either out_extent or the pixel_size metadata of out must be specified" + ) source_extent = source_extent if source_extent is not None else get_extent(source) source_ps = source_extent / np.array(source.shape) out_ps = out_extent / np.array(out_shape) @@ -301,17 +359,38 @@ def project( if out is None: out = np.zeros(out_shape, dtype=source.dtype) # real part - out.real = cv2.warpAffine(source.real, t, out_size, flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, borderValue=(0.0,)) + out.real = cv2.warpAffine( + source.real, + t, + out_size, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0.0,), + ) # imaginary part - out.imag = cv2.warpAffine(source.imag, t, out_size, flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, borderValue=(0.0,)) + out.imag = cv2.warpAffine( + source.imag, + t, + out_size, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0.0,), + ) else: - dst = cv2.warpAffine(source, t, out_size, dst=out, flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, borderValue=(0.0,)) + dst = cv2.warpAffine( + source, + t, + out_size, + dst=out, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0.0,), + ) if out is not None and out is not dst: - raise ValueError("OpenCV did not use the specified output array. This should not happen.") + raise ValueError( + "OpenCV did not use the specified output array. This should not happen." 
+ ) out = dst return set_pixel_size(out, out_ps) @@ -339,7 +418,7 @@ def set_pixel_size(data: ArrayLike, pixel_size: Optional[Quantity]) -> np.ndarra if pixel_size is not None and pixel_size.size == 1: pixel_size = pixel_size * np.ones(data.ndim) - data.dtype = np.dtype(data.dtype, metadata={'pixel_size': pixel_size}) + data.dtype = np.dtype(data.dtype, metadata={"pixel_size": pixel_size}) return data @@ -364,7 +443,7 @@ def get_pixel_size(data: np.ndarray) -> Optional[Quantity]: metadata = data.dtype.metadata if metadata is None: return None - return data.dtype.metadata.get('pixel_size', None) + return data.dtype.metadata.get("pixel_size", None) def get_extent(data: np.ndarray) -> Quantity: diff --git a/tests/test_algorithms_troubleshoot.py b/tests/test_algorithms_troubleshoot.py index 0a76681..8822a1f 100644 --- a/tests/test_algorithms_troubleshoot.py +++ b/tests/test_algorithms_troubleshoot.py @@ -4,9 +4,16 @@ from .test_simulation import phase_response_test_function, lookup_table_test_function from ..openwfs.algorithms import StepwiseSequential -from ..openwfs.algorithms.troubleshoot import cnr, signal_std, find_pixel_shift, \ - field_correlation, frame_correlation, pearson_correlation, \ - measure_modulated_light, measure_modulated_light_dual_phase_stepping +from ..openwfs.algorithms.troubleshoot import ( + cnr, + signal_std, + find_pixel_shift, + field_correlation, + frame_correlation, + pearson_correlation, + measure_modulated_light, + measure_modulated_light_dual_phase_stepping, +) from ..openwfs.processors import SingleRoi from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope @@ -18,8 +25,12 @@ def test_signal_std(): a = np.random.rand(400, 400) b = np.random.rand(400, 400) assert signal_std(a, a) < 1e-6 # Test noise only - assert np.abs(signal_std(a + b, b) - a.std()) < 0.005 # Test signal+uncorrelated noise - assert np.abs(signal_std(a + a, a) - np.sqrt(3) * a.std()) < 0.005 # Test signal+correlated noise + assert ( + np.abs(signal_std(a + b, b) - a.std()) < 0.005 + ) # Test signal+uncorrelated noise + assert ( + np.abs(signal_std(a + a, a) - np.sqrt(3) * a.std()) < 0.005 + ) # Test signal+correlated noise def test_cnr(): @@ -30,8 +41,12 @@ def test_cnr(): b = np.random.randn(800, 800) cnr_gt = 3.0 # Ground Truth assert cnr(a, a) < 1e-6 # Test noise only - assert np.abs(cnr(cnr_gt * a + b, b) - cnr_gt) < 0.01 # Test signal+uncorrelated noise - assert np.abs(cnr(cnr_gt * a + a, a) - np.sqrt((cnr_gt + 1) ** 2 - 1)) < 0.01 # Test signal+correlated noise + assert ( + np.abs(cnr(cnr_gt * a + b, b) - cnr_gt) < 0.01 + ) # Test signal+uncorrelated noise + assert ( + np.abs(cnr(cnr_gt * a + a, a) - np.sqrt((cnr_gt + 1) ** 2 - 1)) < 0.01 + ) # Test signal+correlated noise def test_find_pixel_shift(): @@ -80,8 +95,12 @@ def test_field_correlation(): assert field_correlation(a, a) == 1.0 # Self-correlation assert field_correlation(2 * a, a) == 1.0 # Invariant under scalar-multiplication assert field_correlation(a, b) == 0.0 # Orthogonal arrays - assert np.abs(field_correlation(a + b, b) - np.sqrt(0.5)) < 1e-10 # Self+orthogonal array - assert np.abs(field_correlation(b, c) - np.conj(field_correlation(c, b))) < 1e-10 # Arguments swapped + assert ( + np.abs(field_correlation(a + b, b) - np.sqrt(0.5)) < 1e-10 + ) # Self+orthogonal array + assert ( + np.abs(field_correlation(b, c) - np.conj(field_correlation(c, b))) < 1e-10 + ) # Arguments swapped def test_frame_correlation(): @@ -152,12 +171,16 @@ def test_pearson_correlation_noise_compensated(): assert 
np.isclose(noise1.var(), noise2.var(), atol=2e-3) assert np.isclose(corr_AA, 1, atol=2e-3) assert np.isclose(corr_AB, 0, atol=2e-3) - A_spearman = 1 / np.sqrt((1 + noise1.var() / A1.var()) * (1 + noise2.var() / A2.var())) + A_spearman = 1 / np.sqrt( + (1 + noise1.var() / A1.var()) * (1 + noise2.var() / A2.var()) + ) assert np.isclose(corr_AA_with_noise, A_spearman, atol=2e-3) -@pytest.mark.parametrize("n_y, n_x, phase_steps, b, c, gamma", - [(11, 9, 8, -0.05, 1.5, 0.8), (4, 4, 10, -0.05, 1.5, 0.8)]) +@pytest.mark.parametrize( + "n_y, n_x, phase_steps, b, c, gamma", + [(11, 9, 8, -0.05, 1.5, 0.8), (4, 4, 10, -0.05, 1.5, 0.8)], +) def test_fidelity_phase_calibration_ssa_noise_free(n_y, n_x, phase_steps, b, c, gamma): """ Test computing phase calibration fidelity factor, with the SSA algorithm. Noise-free scenarios. @@ -165,7 +188,9 @@ def test_fidelity_phase_calibration_ssa_noise_free(n_y, n_x, phase_steps, b, c, # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (n_y, n_x)) sim = SimulatedWFS(aberrations=aberrations) - alg = StepwiseSequential(feedback=sim, slm=sim.slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps) + alg = StepwiseSequential( + feedback=sim, slm=sim.slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps + ) result = alg.execute() assert result.fidelity_calibration > 0.99 @@ -181,8 +206,12 @@ def test_fidelity_phase_calibration_ssa_noise_free(n_y, n_x, phase_steps, b, c, assert result.fidelity_calibration > 0.99 -@pytest.mark.parametrize("n_y, n_x, phase_steps, gaussian_noise_std", [(4, 4, 10, 0.2), (6, 6, 12, 1.0)]) -def test_fidelity_phase_calibration_ssa_with_noise(n_y, n_x, phase_steps, gaussian_noise_std): +@pytest.mark.parametrize( + "n_y, n_x, phase_steps, gaussian_noise_std", [(4, 4, 10, 0.2), (6, 6, 12, 1.0)] +) +def test_fidelity_phase_calibration_ssa_with_noise( + n_y, n_x, phase_steps, gaussian_noise_std +): """ Test estimation of phase calibration fidelity factor, with the SSA algorithm. With noise. 
""" @@ -197,26 +226,39 @@ def test_fidelity_phase_calibration_ssa_with_noise(n_y, n_x, phase_steps, gaussi # SLM, simulation, camera, ROI detector slm = SLM(shape=(80, 80)) - sim = Microscope(source=src, incident_field=slm.field, magnification=1, - numerical_aperture=numerical_aperture, - aberrations=aberration, wavelength=800 * u.nm) + sim = Microscope( + source=src, + incident_field=slm.field, + magnification=1, + numerical_aperture=numerical_aperture, + aberrations=aberration, + wavelength=800 * u.nm, + ) cam = sim.get_camera(analog_max=1e4, gaussian_noise_std=gaussian_noise_std) roi_detector = SingleRoi(cam, radius=0) # Only measure that specific point # Define and run WFS algorithm - alg = StepwiseSequential(feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps) + alg = StepwiseSequential( + feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps + ) result_good = alg.execute() assert result_good.fidelity_calibration > 0.9 # SLM with incorrect phase response linear_phase = np.arange(0, 2 * np.pi, 2 * np.pi / 256) - slm.phase_response = phase_response_test_function(linear_phase, b=0.05, c=0.6, gamma=1.5) + slm.phase_response = phase_response_test_function( + linear_phase, b=0.05, c=0.6, gamma=1.5 + ) result_good = alg.execute() assert result_good.fidelity_calibration < 0.9 -@pytest.mark.parametrize("num_blocks, phase_steps, expected_fid, atol", [(10, 8, 1, 1e-6)]) -def test_measure_modulated_light_dual_phase_stepping_noise_free(num_blocks, phase_steps, expected_fid, atol): +@pytest.mark.parametrize( + "num_blocks, phase_steps, expected_fid, atol", [(10, 8, 1, 1e-6)] +) +def test_measure_modulated_light_dual_phase_stepping_noise_free( + num_blocks, phase_steps, expected_fid, atol +): """Test fidelity estimation due to amount of modulated light. Noise-free.""" # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (20, 20)) @@ -224,12 +266,18 @@ def test_measure_modulated_light_dual_phase_stepping_noise_free(num_blocks, phas # Measure the amount of modulated light (no non-modulated light present) fidelity_modulated = measure_modulated_light_dual_phase_stepping( - slm=sim.slm, feedback=sim, phase_steps=phase_steps, num_blocks=num_blocks) + slm=sim.slm, feedback=sim, phase_steps=phase_steps, num_blocks=num_blocks + ) assert np.isclose(fidelity_modulated, expected_fid, atol=atol) -@pytest.mark.parametrize("num_blocks, phase_steps, gaussian_noise_std, atol", [(10, 6, 0.0, 1e-6), (6, 8, 2.0, 1e-3)]) -def test_measure_modulated_light_dual_phase_stepping_with_noise(num_blocks, phase_steps, gaussian_noise_std, atol): +@pytest.mark.parametrize( + "num_blocks, phase_steps, gaussian_noise_std, atol", + [(10, 6, 0.0, 1e-6), (6, 8, 2.0, 1e-3)], +) +def test_measure_modulated_light_dual_phase_stepping_with_noise( + num_blocks, phase_steps, gaussian_noise_std, atol +): """Test fidelity estimation due to amount of modulated light. 
Can test with noise.""" # === Define mock hardware, perfect SLM === # Aberration and image source @@ -239,38 +287,55 @@ def test_measure_modulated_light_dual_phase_stepping_with_noise(num_blocks, phas # SLM, simulation, camera, ROI detector slm = SLM(shape=(100, 100)) - sim = Microscope(source=src, incident_field=slm.field, magnification=1, numerical_aperture=1.0, - wavelength=800 * u.nm) + sim = Microscope( + source=src, + incident_field=slm.field, + magnification=1, + numerical_aperture=1.0, + wavelength=800 * u.nm, + ) cam = sim.get_camera(analog_max=1e4, gaussian_noise_std=gaussian_noise_std) roi_detector = SingleRoi(cam, radius=0) # Only measure that specific point # Measure the amount of modulated light (no non-modulated light present) fidelity_modulated = measure_modulated_light_dual_phase_stepping( - slm=slm, feedback=roi_detector, phase_steps=phase_steps, num_blocks=num_blocks) + slm=slm, feedback=roi_detector, phase_steps=phase_steps, num_blocks=num_blocks + ) assert np.isclose(fidelity_modulated, 1, atol=atol) @pytest.mark.parametrize( - "phase_steps, modulated_field_amplitude, non_modulated_field", [(6, 1.0, 0.0), (8, 0.5, 0.5), (8, 1.0, 0.25)]) -def test_measure_modulated_light_noise_free(phase_steps, modulated_field_amplitude, non_modulated_field): + "phase_steps, modulated_field_amplitude, non_modulated_field", + [(6, 1.0, 0.0), (8, 0.5, 0.5), (8, 1.0, 0.25)], +) +def test_measure_modulated_light_noise_free( + phase_steps, modulated_field_amplitude, non_modulated_field +): """Test fidelity estimation due to amount of modulated light. Noise-free.""" # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (20, 20)) - slm = SLM(aberrations.shape, field_amplitude=modulated_field_amplitude, - non_modulated_field_fraction=non_modulated_field) + slm = SLM( + aberrations.shape, + field_amplitude=modulated_field_amplitude, + non_modulated_field_fraction=non_modulated_field, + ) sim = SimulatedWFS(aberrations=aberrations, slm=slm) # Measure the amount of modulated light (no non-modulated light present) - fidelity_modulated = measure_modulated_light(slm=sim.slm, feedback=sim, phase_steps=phase_steps) - expected_fid = 1.0 / (1.0 + non_modulated_field ** 2) + fidelity_modulated = measure_modulated_light( + slm=sim.slm, feedback=sim, phase_steps=phase_steps + ) + expected_fid = 1.0 / (1.0 + non_modulated_field**2) assert np.isclose(fidelity_modulated, expected_fid, rtol=0.1) @pytest.mark.parametrize( "phase_steps, gaussian_noise_std, modulated_field_amplitude, non_modulated_field", - [(8, 0.0, 0.5, 0.4), (6, 0.0, 1.0, 0.0), (12, 2.0, 1.0, 0.25)]) + [(8, 0.0, 0.5, 0.4), (6, 0.0, 1.0, 0.0), (12, 2.0, 1.0, 0.25)], +) def test_measure_modulated_light_dual_phase_stepping_with_noise( - phase_steps, gaussian_noise_std, modulated_field_amplitude, non_modulated_field): + phase_steps, gaussian_noise_std, modulated_field_amplitude, non_modulated_field +): """Test fidelity estimation due to amount of modulated light. 
Can test with noise.""" # === Define mock hardware, perfect SLM === # Aberration and image source @@ -279,14 +344,18 @@ def test_measure_modulated_light_dual_phase_stepping_with_noise( src = StaticSource(img, 200 * u.nm) # SLM, simulation, camera, ROI detector - slm = SLM(shape=(100, 100), - field_amplitude=modulated_field_amplitude, - non_modulated_field_fraction=non_modulated_field) + slm = SLM( + shape=(100, 100), + field_amplitude=modulated_field_amplitude, + non_modulated_field_fraction=non_modulated_field, + ) sim = Microscope(source=src, incident_field=slm.field, wavelength=800 * u.nm) cam = sim.get_camera(analog_max=1e3, gaussian_noise_std=gaussian_noise_std) roi_detector = SingleRoi(cam, radius=0) # Only measure that specific point # Measure the amount of modulated light (no non-modulated light present) - expected_fid = 1.0 / (1.0 + non_modulated_field ** 2) - fidelity_modulated = measure_modulated_light(slm=slm, feedback=roi_detector, phase_steps=phase_steps) + expected_fid = 1.0 / (1.0 + non_modulated_field**2) + fidelity_modulated = measure_modulated_light( + slm=slm, feedback=roi_detector, phase_steps=phase_steps + ) assert np.isclose(fidelity_modulated, expected_fid, rtol=0.1) diff --git a/tests/test_camera.py b/tests/test_camera.py index d953491..f635c90 100644 --- a/tests/test_camera.py +++ b/tests/test_camera.py @@ -1,7 +1,9 @@ import pytest -pytest.importorskip('harvesters', - reason='harvesters is required for the Camera module, install with pip install harvesters') +pytest.importorskip( + "harvesters", + reason="harvesters is required for the Camera module, install with pip install harvesters", +) from ..openwfs.devices import Camera @@ -36,8 +38,16 @@ def test_roi(camera, binning, top, left): # take care that the size will be a multiple of the increment, # and that setting the binning will round this number down camera.binning = binning - expected_width = (original_shape[1] // binning) // camera._nodes.Width.inc * camera._nodes.Width.inc - expected_height = (original_shape[0] // binning) // camera._nodes.Height.inc * camera._nodes.Height.inc + expected_width = ( + (original_shape[1] // binning) + // camera._nodes.Width.inc + * camera._nodes.Width.inc + ) + expected_height = ( + (original_shape[0] // binning) + // camera._nodes.Height.inc + * camera._nodes.Height.inc + ) assert camera.data_shape == (expected_height, expected_width) # check if setting the ROI works diff --git a/tests/test_core.py b/tests/test_core.py index 847d032..89502d8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -75,15 +75,17 @@ def test_timing_detector(caplog, duration): assert np.allclose(f0.result(), image0) t5 = time.time_ns() - assert np.allclose(t1 - t0, 0.0, atol=0.1E9) - assert np.allclose(t2 - t1, duration.to_value(u.ns), atol=0.1E9) - assert np.allclose(t3 - t2, 0.0, atol=0.1E9) - assert np.allclose(t4 - t3, duration.to_value(u.ns), atol=0.1E9) - assert np.allclose(t5 - t4, 0.0, atol=0.1E9) + assert np.allclose(t1 - t0, 0.0, atol=0.1e9) + assert np.allclose(t2 - t1, duration.to_value(u.ns), atol=0.1e9) + assert np.allclose(t3 - t2, 0.0, atol=0.1e9) + assert np.allclose(t4 - t3, duration.to_value(u.ns), atol=0.1e9) + assert np.allclose(t5 - t4, 0.0, atol=0.1e9) def test_noise_detector(): - source = NoiseSource('uniform', data_shape=(10, 11, 20), low=-1.0, high=1.0, pixel_size=4 * u.um) + source = NoiseSource( + "uniform", data_shape=(10, 11, 20), low=-1.0, high=1.0, pixel_size=4 * u.um + ) data = source.read() assert data.shape == (10, 11, 20) assert np.min(data) >= -1.0 @@ 
-98,18 +100,31 @@ def test_noise_detector(): def test_mock_slm(): slm = SLM((4, 4)) slm.set_phases(0.5) - assert np.allclose(slm.pixels.read(), round(0.5 * 256 / (2 * np.pi)), atol=0.5 / 256) + assert np.allclose( + slm.pixels.read(), round(0.5 * 256 / (2 * np.pi)), atol=0.5 / 256 + ) discretized_phase = slm.phases.read() assert np.allclose(discretized_phase, 0.5, atol=1.1 * np.pi / 256) - assert np.allclose(slm.field.read(), np.exp(1j * discretized_phase[0, 0]), rtol=2 / 256) + assert np.allclose( + slm.field.read(), np.exp(1j * discretized_phase[0, 0]), rtol=2 / 256 + ) slm.set_phases(np.array(((0.1, 0.2), (0.3, 0.4))), update=False) - assert np.allclose(slm.phases.read(), 0.5, atol=1.1 * np.pi / 256) # slm.update() not yet called, so should be 0.5 + assert np.allclose( + slm.phases.read(), 0.5, atol=1.1 * np.pi / 256 + ) # slm.update() not yet called, so should be 0.5 slm.update() - assert np.allclose(slm.phases.read(), np.array(( - (0.1, 0.1, 0.2, 0.2), - (0.1, 0.1, 0.2, 0.2), - (0.3, 0.3, 0.4, 0.4), - (0.3, 0.3, 0.4, 0.4))), atol=1.1 * np.pi / 256) + assert np.allclose( + slm.phases.read(), + np.array( + ( + (0.1, 0.1, 0.2, 0.2), + (0.1, 0.1, 0.2, 0.2), + (0.3, 0.3, 0.4, 0.4), + (0.3, 0.3, 0.4, 0.4), + ) + ), + atol=1.1 * np.pi / 256, + ) def test_crop(): @@ -160,6 +175,7 @@ def test_crop_1d(): assert c3.shape == cropped.data_shape assert np.all(c3 == data[4:6]) + # TODO: translate the tests below. # They should test the SingleROI processor, checking if the returned averaged value is correct. # diff --git a/tests/test_processors.py b/tests/test_processors.py index de74867..4a847a4 100644 --- a/tests/test_processors.py +++ b/tests/test_processors.py @@ -6,19 +6,19 @@ import astropy.units as u -@pytest.mark.skip(reason="This is an interactive test: skip by default. TODO: actually test if the roi was " - "selected correctly.") +@pytest.mark.skip( + reason="This is an interactive test: skip by default. TODO: actually test if the roi was " + "selected correctly." +) def test_croppers(): img = sk.data.camera() src = StaticSource(img, 50 * u.nm) - roi = select_roi(src, 'disk') - assert roi.mask_type == 'disk' + roi = select_roi(src, "disk") + assert roi.mask_type == "disk" def test_single_roi_simple_case(): - data = np.array([[1, 2, 3], - [4, 5, 6], - [7, 8, 9]]) + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) pixel_size = 1 * np.ones(2) mock_source = StaticSource(data, pixel_size=pixel_size) roi_processor = SingleRoi(mock_source, radius=np.sqrt(2)) @@ -29,8 +29,9 @@ def test_single_roi_simple_case(): print("Mask:", roi_processor._rois[()]._mask) expected_value = np.mean(data[0:3, 0:3]) # Assuming this is how the ROI is defined - assert np.isclose(result, - expected_value), f"ROI average value is incorrect. Expected: {expected_value}, Got: {result}" + assert np.isclose( + result, expected_value + ), f"ROI average value is incorrect. 
Expected: {expected_value}, Got: {result}" def create_mock_source_with_data(): @@ -38,36 +39,41 @@ def create_mock_source_with_data(): return StaticSource(data, pixel_size=1 * u.um) -@pytest.mark.parametrize("x, y, radius, expected_avg", [ - (2, 2, 1, 12), # Center ROI in 5x5 matrix - (0, 0, 0, 0) # Top-left corner ROI in 5x5 matrix -]) +@pytest.mark.parametrize( + "x, y, radius, expected_avg", + [ + (2, 2, 1, 12), # Center ROI in 5x5 matrix + (0, 0, 0, 0), # Top-left corner ROI in 5x5 matrix + ], +) def test_single_roi(x, y, radius, expected_avg): mock_source = create_mock_source_with_data() roi_processor = SingleRoi(mock_source, (y, x), radius) roi_processor.trigger() result = roi_processor.read() - assert np.isclose(result, expected_avg), f"ROI average value is incorrect. Expected: {expected_avg}, Got: {result}" + assert np.isclose( + result, expected_avg + ), f"ROI average value is incorrect. Expected: {expected_avg}, Got: {result}" def test_multiple_roi_simple_case(): - data = np.array([[1, 2, 3], - [4, 5, 6], - [7, 8, 9]]) + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) pixel_size = 1 * np.ones(2) mock_source = StaticSource(data, pixel_size=pixel_size) - rois = [Roi((1, 1), radius=0), - Roi((2, 2), radius=0), - Roi((1, 1), radius=1), - Roi((0, 1), radius=0) - ] + rois = [ + Roi((1, 1), radius=0), + Roi((2, 2), radius=0), + Roi((1, 1), radius=1), + Roi((0, 1), radius=0), + ] roi_processor = MultipleRoi(mock_source, rois=rois) roi_processor.trigger() result = roi_processor.read() expected_values = [5, 9, 5, 2] - assert all(np.isclose(r, e) for r, e in zip(result, expected_values)), \ - f"ROI average values are incorrect. Expected: {expected_values}, Got: {result}" + assert all( + np.isclose(r, e) for r, e in zip(result, expected_values) + ), f"ROI average values are incorrect. 
Expected: {expected_values}, Got: {result}" diff --git a/tests/test_scanning_microscope.py b/tests/test_scanning_microscope.py index f6a8c45..fe0335a 100644 --- a/tests/test_scanning_microscope.py +++ b/tests/test_scanning_microscope.py @@ -2,8 +2,10 @@ import numpy as np import pytest -pytest.importorskip('nidaqmx', - reason='nidaqmx is required for the ScanningMicroscope module, install with pip install nidaqmx') +pytest.importorskip( + "nidaqmx", + reason="nidaqmx is required for the ScanningMicroscope module, install with pip install nidaqmx", +) from ..openwfs.devices import ScanningMicroscope, Axis from ..openwfs.devices.galvo_scanner import InputChannel @@ -13,12 +15,18 @@ @pytest.mark.parametrize("start, stop", [(0.0, 1.0), (1.0, 0.0)]) def test_scan_axis(start, stop): """Tests if the Axis class generates the correct voltage sequences for stepping and scanning.""" - maximum_acceleration = 1 * u.V / u.ms ** 2 + maximum_acceleration = 1 * u.V / u.ms**2 scale = 440 * u.um / u.V v_min = -1.0 * u.V v_max = 2.0 * u.V - a = Axis(channel='Dev4/ao0', v_min=v_min, v_max=v_max, maximum_acceleration=maximum_acceleration, scale=scale) - assert a.channel == 'Dev4/ao0' + a = Axis( + channel="Dev4/ao0", + v_min=v_min, + v_max=v_max, + maximum_acceleration=maximum_acceleration, + scale=scale, + ) + assert a.channel == "Dev4/ao0" assert a.v_min == v_min assert a.v_max == v_max assert a.maximum_acceleration == maximum_acceleration @@ -36,19 +44,22 @@ def test_scan_axis(start, stop): assert np.isclose(step[-1], 2.0 * u.V if start == 0.0 else -1.0 * u.V) assert np.all(step >= v_min) assert np.all(step <= v_max) - acceleration = np.diff(np.diff(step)) * sample_rate ** 2 + acceleration = np.diff(np.diff(step)) * sample_rate**2 assert np.all(np.abs(acceleration) <= maximum_acceleration * 1.01) center = 0.5 * (start + stop) amplitude = 0.5 * (stop - start) # test clipping - assert np.allclose(step, a.step(center - 1.1 * amplitude, center + 1.1 * amplitude, sample_rate)) + assert np.allclose( + step, a.step(center - 1.1 * amplitude, center + 1.1 * amplitude, sample_rate) + ) # test scan. 
Note that we cannot use the full scan range because we need # some time to accelerate / decelerate sample_count = 10000 - scan, launch, land, linear_region = a.scan(center - 0.8 * amplitude, center + 0.8 * amplitude, sample_count, - sample_rate) + scan, launch, land, linear_region = a.scan( + center - 0.8 * amplitude, center + 0.8 * amplitude, sample_count, sample_rate + ) half_pixel = 0.8 * amplitude / sample_count # plt.plot(scan) # plt.show() @@ -56,48 +67,76 @@ def test_scan_axis(start, stop): assert linear_region.start == len(scan) - linear_region.stop assert np.isclose(scan[0], a.to_volt(launch)) assert np.isclose(scan[-1], a.to_volt(land)) - assert np.isclose(scan[linear_region.start], a.to_volt(center - 0.8 * amplitude + half_pixel)) - assert np.isclose(scan[linear_region.stop - 1], a.to_volt(center + 0.8 * amplitude - half_pixel)) + assert np.isclose( + scan[linear_region.start], a.to_volt(center - 0.8 * amplitude + half_pixel) + ) + assert np.isclose( + scan[linear_region.stop - 1], a.to_volt(center + 0.8 * amplitude - half_pixel) + ) speed = np.diff(scan[linear_region]) assert np.allclose(speed, speed[0]) # speed should be constant - acceleration = np.diff(np.diff(scan)) * sample_rate ** 2 + acceleration = np.diff(np.diff(scan)) * sample_rate**2 assert np.all(np.abs(acceleration) <= maximum_acceleration * 1.01) def make_scanner(bidirectional, direction, reference_zoom): scale = 440 * u.um / u.V sample_rate = 0.5 * u.MHz - input_channel = InputChannel(channel='Dev4/ai0', v_min=-1.0 * u.V, v_max=1.0 * u.V) - y_axis = Axis(channel='Dev4/ao0', v_min=-2.0 * u.V, v_max=2.0 * u.V, maximum_acceleration=10 * u.V / u.ms ** 2, - scale=scale) - x_axis = Axis(channel='Dev4/ao1', v_min=-2.0 * u.V, v_max=2.0 * u.V, maximum_acceleration=10 * u.V / u.ms ** 2, - scale=scale) - return ScanningMicroscope(bidirectional=bidirectional, sample_rate=sample_rate, resolution=1024, - input=input_channel, y_axis=y_axis, x_axis=x_axis, - test_pattern=direction, reference_zoom=reference_zoom) - - -@pytest.mark.parametrize("direction", ['horizontal', 'vertical']) + input_channel = InputChannel(channel="Dev4/ai0", v_min=-1.0 * u.V, v_max=1.0 * u.V) + y_axis = Axis( + channel="Dev4/ao0", + v_min=-2.0 * u.V, + v_max=2.0 * u.V, + maximum_acceleration=10 * u.V / u.ms**2, + scale=scale, + ) + x_axis = Axis( + channel="Dev4/ao1", + v_min=-2.0 * u.V, + v_max=2.0 * u.V, + maximum_acceleration=10 * u.V / u.ms**2, + scale=scale, + ) + return ScanningMicroscope( + bidirectional=bidirectional, + sample_rate=sample_rate, + resolution=1024, + input=input_channel, + y_axis=y_axis, + x_axis=x_axis, + test_pattern=direction, + reference_zoom=reference_zoom, + ) + + +@pytest.mark.parametrize("direction", ["horizontal", "vertical"]) @pytest.mark.parametrize("bidirectional", [False, True]) def test_scan_pattern(direction, bidirectional): """A unit test for scanning patterns.""" reference_zoom = 1.2 scanner = make_scanner(bidirectional, direction, reference_zoom) - assert np.allclose(scanner.extent, scanner._x_axis.scale * 4.0 * u.V / reference_zoom) + assert np.allclose( + scanner.extent, scanner._x_axis.scale * 4.0 * u.V / reference_zoom + ) # plt.imshow(scanner.read()) # plt.show() # check if returned pattern is correct - (y, x) = coordinate_range((scanner._resolution, scanner._resolution), - 10000 / reference_zoom, offset=(5000, 5000)) - full = scanner.read().astype('float32') - 0x8000 + (y, x) = coordinate_range( + (scanner._resolution, scanner._resolution), + 10000 / reference_zoom, + offset=(5000, 5000), + ) + full = 
scanner.read().astype("float32") - 0x8000 pixel_size = full[1, 1] - full[0, 0] - if direction == 'horizontal': + if direction == "horizontal": assert np.allclose(full, full[0, :]) # all rows should be the same - assert np.allclose(x, full, atol=0.2 * pixel_size) # some rounding due to quantization + assert np.allclose( + x, full, atol=0.2 * pixel_size + ) # some rounding due to quantization else: # all columns should be the same (note we need to keep the last dimension for correct broadcasting) assert np.allclose(full, full[:, 0:1]) @@ -119,15 +158,17 @@ def test_scan_pattern(direction, bidirectional): assert scanner.height == height assert scanner.data_shape == (height, width) - roi = scanner.read().astype('float32') - 0x8000 - assert np.allclose(full[top:(top + height), left:(left + width)], roi, atol=0.2 * pixel_size) + roi = scanner.read().astype("float32") - 0x8000 + assert np.allclose( + full[top : (top + height), left : (left + width)], roi, atol=0.2 * pixel_size + ) @pytest.mark.parametrize("bidirectional", [False, True]) def test_park_beam(bidirectional): """A unit test for parking the beam of a DAQ scanner.""" reference_zoom = 1.2 - scanner = make_scanner(bidirectional, 'horizontal', reference_zoom) + scanner = make_scanner(bidirectional, "horizontal", reference_zoom) # Park beam horizontally scanner.top = 3 @@ -138,7 +179,9 @@ def test_park_beam(bidirectional): img = scanner.read() assert img.shape == (2, 1) voltages = scanner._scan_pattern - assert np.allclose(voltages[1, :], voltages[1, 0]) # all voltages should be the same + assert np.allclose( + voltages[1, :], voltages[1, 0] + ) # all voltages should be the same # Park beam vertically scanner.width = 2 @@ -154,8 +197,13 @@ def test_park_beam(bidirectional): img = scanner.read() assert img.shape == (1, 1) voltages = scanner._scan_pattern - assert np.allclose(voltages[1, :], voltages[1, 0]) # all voltages should be the same - assert np.allclose(voltages[0, :], voltages[0, 0]) # all voltages should be the same + assert np.allclose( + voltages[1, :], voltages[1, 0] + ) # all voltages should be the same + assert np.allclose( + voltages[0, :], voltages[0, 0] + ) # all voltages should be the same + # test zooming # ps = scanner.pixel_size diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 236ceba..5ae8443 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -21,8 +21,12 @@ def test_mock_camera_and_single_roi(): img = np.zeros((1000, 1000), dtype=np.int16) img[200, 300] = 39.39 # some random float src = Camera(StaticSource(img, 450 * u.nm)) - roi_detector = SingleRoi(src, pos=(200, 300), radius=0) # Only measure that specific point - assert roi_detector.read() == int(2 ** 16 - 1) # it should cast the array into some int + roi_detector = SingleRoi( + src, pos=(200, 300), radius=0 + ) # Only measure that specific point + assert roi_detector.read() == int( + 2**16 - 1 + ) # it should cast the array into some int @pytest.mark.parametrize("shape", [(1000, 1000), (999, 999)]) @@ -37,11 +41,13 @@ def test_microscope_without_magnification(shape): src = Camera(StaticSource(img, 400 * u.nm)) # construct microscope - sim = Microscope(source=src, magnification=1, numerical_aperture=1, wavelength=800 * u.nm) + sim = Microscope( + source=src, magnification=1, numerical_aperture=1, wavelength=800 * u.nm + ) cam = sim.get_camera() img = cam.read() - assert img[256, 256] == 2 ** 16 - 1 + assert img[256, 256] == 2**16 - 1 def test_microscope_and_aberration(): @@ -56,7 +62,13 @@ def 
test_microscope_and_aberration(): aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) - sim = Microscope(source=src, magnification=1, incident_field=slm.field, numerical_aperture=1, wavelength=800 * u.nm) + sim = Microscope( + source=src, + magnification=1, + incident_field=slm.field, + numerical_aperture=1, + wavelength=800 * u.nm, + ) without_aberration = sim.read()[256, 256] slm.set_phases(aberrations) @@ -78,10 +90,17 @@ def test_slm_and_aberration(): aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) * 0 slm.set_phases(-aberrations) - aberration = StaticSource(aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled) - - sim1 = Microscope(source=src, incident_field=slm.field, numerical_aperture=1.0, aberrations=aberration, - wavelength=800 * u.nm) + aberration = StaticSource( + aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled + ) + + sim1 = Microscope( + source=src, + incident_field=slm.field, + numerical_aperture=1.0, + aberrations=aberration, + wavelength=800 * u.nm, + ) sim2 = Microscope(source=src, numerical_aperture=1.0, wavelength=800 * u.nm) # We correlate the two. @@ -111,14 +130,19 @@ def test_slm_tilt(): slm = SLM(shape=(1000, 1000)) na = 1.0 - sim = Microscope(source=src, incident_field=slm.field, magnification=1, numerical_aperture=na, - wavelength=wavelength) + sim = Microscope( + source=src, + incident_field=slm.field, + magnification=1, + numerical_aperture=na, + wavelength=wavelength, + ) # introduce a tilted pupil plane # the input parameter to `tilt` corresponds to a shift 2.0/π the Abbe diffraction limit. shift = np.array((-24, 40)) step = wavelength / (np.pi * na) - slm.set_phases(tilt(1000, - shift * pixel_size / step)) + slm.set_phases(tilt(1000, -shift * pixel_size / step)) new_location = signal_location + shift @@ -135,7 +159,9 @@ def test_microscope_wavefront_shaping(caplog): # caplog.set_level(logging.DEBUG) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) + np.pi - aberration = StaticSource(aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled) # note: incorrect scaling! + aberration = StaticSource( + aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled + ) # note: incorrect scaling! img = np.zeros((1000, 1000), dtype=np.int16) img[256, 256] = 100 @@ -148,13 +174,22 @@ def test_microscope_wavefront_shaping(caplog): slm = SLM(shape=(1000, 1000)) - sim = Microscope(source=src, incident_field=slm.field, numerical_aperture=1, aberrations=aberration, - wavelength=800 * u.nm) + sim = Microscope( + source=src, + incident_field=slm.field, + numerical_aperture=1, + aberrations=aberration, + wavelength=800 * u.nm, + ) cam = sim.get_camera(analog_max=100) - roi_detector = SingleRoi(cam, pos=signal_location, radius=0) # Only measure that specific point + roi_detector = SingleRoi( + cam, pos=signal_location, radius=0 + ) # Only measure that specific point - alg = StepwiseSequential(feedback=roi_detector, slm=slm, phase_steps=3, n_x=3, n_y=3) + alg = StepwiseSequential( + feedback=roi_detector, slm=slm, phase_steps=3, n_x=3, n_y=3 + ) t = alg.execute().t # test if the modes differ. 
The error causes them not to differ @@ -186,10 +221,36 @@ def test_mock_slm_lut_and_phase_response(): """ # === Test default lookup table and phase response === # Includes edge cases like rounding/wrapping: -0.501 -> 255, -0.499 -> 0 - input_phases_a = np.asarray( - (-1, -0.501, -0.499, 0, 1, 64, 128, 192, 255, 255.499, 255.501, 256, 257, 511, 512)) * 2 * np.pi / 256 - expected_output_phases_a = np.asarray( - (255, 255, 0, 0, 1, 64, 128, 192, 255, 255, 0, 0, 1, 255, 0)) * 2 * np.pi / 256 + input_phases_a = ( + np.asarray( + ( + -1, + -0.501, + -0.499, + 0, + 1, + 64, + 128, + 192, + 255, + 255.499, + 255.501, + 256, + 257, + 511, + 512, + ) + ) + * 2 + * np.pi + / 256 + ) + expected_output_phases_a = ( + np.asarray((255, 255, 0, 0, 1, 64, 128, 192, 255, 255, 0, 0, 1, 255, 0)) + * 2 + * np.pi + / 256 + ) slm1 = SLM(shape=(3, input_phases_a.shape[0])) slm1.set_phases(input_phases_a) assert np.all(np.abs(slm1.phases.read() - expected_output_phases_a) < 1e6) @@ -213,13 +274,22 @@ def test_mock_slm_lut_and_phase_response(): slm3 = SLM(shape=(3, 256)) slm3.lookup_table = lookup_table slm3.set_phases(linear_phase) - assert np.all(np.abs(slm3.phases.read() - inverse_phase_response_test_function(linear_phase, b, c, gamma)) < ( - 1.1 * np.pi / 256)) + assert np.all( + np.abs( + slm3.phases.read() + - inverse_phase_response_test_function(linear_phase, b, c, gamma) + ) + < (1.1 * np.pi / 256) + ) # === Test custom lookup table that counters custom synthetic phase response === - linear_phase_highres = np.arange(0, 2 * np.pi * 255.49 / 256, 0.25 * 2 * np.pi / 256) + linear_phase_highres = np.arange( + 0, 2 * np.pi * 255.49 / 256, 0.25 * 2 * np.pi / 256 + ) slm4 = SLM(shape=(3, linear_phase_highres.shape[0])) slm4.phase_response = phase_response slm4.lookup_table = lookup_table slm4.set_phases(linear_phase_highres) - assert np.all(np.abs(slm4.phases.read()[0] - linear_phase_highres) < (3 * np.pi / 256)) + assert np.all( + np.abs(slm4.phases.read()[0] - linear_phase_highres) < (3 * np.pi / 256) + ) diff --git a/tests/test_slm.py b/tests/test_slm.py index 5095a90..b4575d0 100644 --- a/tests/test_slm.py +++ b/tests/test_slm.py @@ -19,7 +19,7 @@ @pytest.fixture def slm() -> SLM: - slm = SLM(monitor_id=0, shape=(100, 200), pos=(20, 10), coordinate_system='full') + slm = SLM(monitor_id=0, shape=(100, 200), pos=(20, 10), coordinate_system="full") return slm @@ -30,7 +30,7 @@ def test_create_windowed(slm): assert slm.shape == (100, 200) assert slm.position == (20, 10) assert slm.transform == Transform.identity() - assert slm.coordinate_system == 'full' + assert slm.coordinate_system == "full" # check if frame buffer has correct size fb_texture = slm._frame_buffer._textures[Patch._PHASES_TEXTURE] @@ -110,7 +110,7 @@ def test_transform(slm): # now change the transform to 'short' to fit the pattern to a centered square, with the height of the # SLM. # Then check if the pattern is displayed correctly - slm.coordinate_system = 'short' # does not trigger an update + slm.coordinate_system = "short" # does not trigger an update assert np.all(slm.pixels.read() / 64 == pixels) slm.update() @@ -126,7 +126,7 @@ def test_transform(slm): # now change the transform to 'long' to fit the pattern to a centered square, with the width of the # SLM, causing part of the texture to be mapped outside the window. 
- slm.coordinate_system = 'long' # does not trigger an update + slm.coordinate_system = "long" # does not trigger an update assert np.all(slm.pixels.read() / 64 == pixels) slm.update() @@ -137,7 +137,7 @@ def test_transform(slm): assert np.allclose(pixels[:, 100:], 3) # test zooming the pattern - slm.coordinate_system = 'short' + slm.coordinate_system = "short" slm.transform = Transform.zoom(0.8) slm.update() @@ -153,8 +153,10 @@ def test_transform(slm): assert np.allclose(sub[20:, 40:], 3) -@pytest.mark.skip(reason="This test is skipped by default because it causes the screen to flicker, which may " - "affect people with epilepsy.") +@pytest.mark.skip( + reason="This test is skipped by default because it causes the screen to flicker, which may " + "affect people with epilepsy." +) def test_refresh_rate(): slm = SLM(1, latency=0, duration=0) refresh_rate = slm.refresh_rate @@ -171,14 +173,16 @@ def test_refresh_rate(): stop = time.time_ns() * u.ns del slm actual_refresh_rate = frame_count / (stop - start) - assert np.allclose(refresh_rate.to_value(u.Hz), actual_refresh_rate.to_value(u.Hz), rtol=1e-2) + assert np.allclose( + refresh_rate.to_value(u.Hz), actual_refresh_rate.to_value(u.Hz), rtol=1e-2 + ) def test_get_pixels(): width = 73 height = 99 slm = SLM(SLM.WINDOWED, shape=(height, width)) - slm.coordinate_system = 'full' # fill full screen exactly (anisotropic coordinates + slm.coordinate_system = "full" # fill full screen exactly (anisotropic coordinates pattern = np.random.uniform(size=(height, width)) * 2 * np.pi slm.set_phases(pattern) read_back = slm.pixels.read() @@ -257,8 +261,22 @@ def test_circular_geometry(slm): # read back the pixels and verify conversion to gray values pixels = np.rint(slm.pixels.read() / 256 * 70) - polar_pixels = cv2.warpPolar(pixels, (100, 40), (99.5, 99.5), 100, cv2.WARP_POLAR_LINEAR) - - assert np.allclose(polar_pixels[:, 3:24], np.repeat(np.flip(np.arange(0, 10)), 4).reshape((-1, 1)), atol=1) - assert np.allclose(polar_pixels[:, 27:47], np.repeat(np.flip(np.arange(10, 30)), 2).reshape((-1, 1)), atol=1) - assert np.allclose(polar_pixels[:, 53:97], np.repeat(np.flip(np.arange(30, 70)), 1).reshape((-1, 1)), atol=1) + polar_pixels = cv2.warpPolar( + pixels, (100, 40), (99.5, 99.5), 100, cv2.WARP_POLAR_LINEAR + ) + + assert np.allclose( + polar_pixels[:, 3:24], + np.repeat(np.flip(np.arange(0, 10)), 4).reshape((-1, 1)), + atol=1, + ) + assert np.allclose( + polar_pixels[:, 27:47], + np.repeat(np.flip(np.arange(10, 30)), 2).reshape((-1, 1)), + atol=1, + ) + assert np.allclose( + polar_pixels[:, 53:97], + np.repeat(np.flip(np.arange(30, 70)), 1).reshape((-1, 1)), + atol=1, + ) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index becf964..ea5d698 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,13 +1,21 @@ import numpy as np -from ..openwfs.utilities import set_pixel_size, get_pixel_size, place, Transform, project +from ..openwfs.utilities import ( + set_pixel_size, + get_pixel_size, + place, + Transform, + project, +) import astropy.units as u def test_to_matrix(): # Create a transform object - transform = Transform(transform=((1, 2), (3, 4)), - source_origin=(0.0, 0.0) * u.m, - destination_origin=(0.001, 0.002) * u.mm) + transform = Transform( + transform=((1, 2), (3, 4)), + source_origin=(0.0, 0.0) * u.m, + destination_origin=(0.001, 0.002) * u.mm, + ) # Define the expected output matrix for same input and output pixel sizes expected_matrix = ((1, 2, 1), (3, 4, 1)) @@ -26,23 +34,29 @@ def test_to_matrix(): src_center 
= np.array((0.5 * (src[1] - 1), 0.5 * (src[0] - 1), 1.0)) dst_center = np.array((0.5 * (dst[1] - 1), 0.5 * (dst[0] - 1))) transform = Transform() - result_matrix = transform.cv2_matrix(source_shape=src, - source_pixel_size=(1, 1), destination_shape=dst, - destination_pixel_size=(1, 1)) + result_matrix = transform.cv2_matrix( + source_shape=src, + source_pixel_size=(1, 1), + destination_shape=dst, + destination_pixel_size=(1, 1), + ) assert np.allclose(result_matrix @ src_center, dst_center) # Test center correction. The center of the source image should be mapped to the center of the destination image transform = Transform() # transform=((1, 2), (3, 4))) - result_matrix = transform.cv2_matrix(source_shape=src, - source_pixel_size=(0.5, 4) * u.um, destination_shape=dst, - destination_pixel_size=(1, 2) * u.um) + result_matrix = transform.cv2_matrix( + source_shape=src, + source_pixel_size=(0.5, 4) * u.um, + destination_shape=dst, + destination_pixel_size=(1, 2) * u.um, + ) assert np.allclose(result_matrix @ src_center, dst_center) # Also check openGL matrix (has y-axis flipped and extra row and column) expected_matrix = ((1, 2, 1), (3, 4, 2)) - transform = Transform(transform=((1, 2), (3, 4)), - source_origin=(0, 0), - destination_origin=(1, 2)) + transform = Transform( + transform=((1, 2), (3, 4)), source_origin=(0, 0), destination_origin=(1, 2) + ) result_matrix = transform.to_matrix((1, 1), (1, 1)) assert np.allclose(result_matrix, expected_matrix) @@ -96,16 +110,28 @@ def test_transform(): assert np.allclose(matrix, ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0))) # shift both origins by same distance - t0 = Transform(source_origin=-ps1 * (1.7, 2.2), destination_origin=-ps1 * (1.7, 2.2)) - dst0 = project(src, source_extent=ps1 * np.array(src.shape), transform=t0, out_extent=ps1 * np.array(src.shape), - out_shape=src.shape) + t0 = Transform( + source_origin=-ps1 * (1.7, 2.2), destination_origin=-ps1 * (1.7, 2.2) + ) + dst0 = project( + src, + source_extent=ps1 * np.array(src.shape), + transform=t0, + out_extent=ps1 * np.array(src.shape), + out_shape=src.shape, + ) assert np.allclose(dst0, src) # shift source by (1,2) pixel t1 = Transform(source_origin=-ps1 * (1, 2)) dst1a = place(src.shape, ps1, src, offset=ps1 * (1, 2)) - dst1b = project(src, source_extent=ps1 * np.array(src.shape), transform=t1, out_extent=ps1 * np.array(src.shape), - out_shape=src.shape) + dst1b = project( + src, + source_extent=ps1 * np.array(src.shape), + transform=t1, + out_extent=ps1 * np.array(src.shape), + out_shape=src.shape, + ) assert np.allclose(dst1a, dst1b) @@ -121,7 +147,8 @@ def test_inverse(): transform = Transform( transform=((0.1, 0.2), (-0.25, 0.33)), source_origin=(0.12, 0.15), - destination_origin=(0.23, 0.33)) + destination_origin=(0.23, 0.33), + ) vector = (0.3, 0.4) result = transform.apply(vector) diff --git a/tests/test_wfs.py b/tests/test_wfs.py index 7a4febf..10244fe 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -7,12 +7,24 @@ from scipy.ndimage import zoom from skimage.transform import resize -from ..openwfs.algorithms import StepwiseSequential, FourierDualReference, FourierDualReferenceCircle, \ - CustomIterativeDualReference, troubleshoot +from ..openwfs.algorithms import ( + StepwiseSequential, + FourierDualReference, + FourierDualReferenceCircle, + CustomIterativeDualReference, + troubleshoot, +) from ..openwfs.algorithms.troubleshoot import field_correlation from ..openwfs.algorithms.utilities import WFSController from ..openwfs.processors import SingleRoi -from ..openwfs.simulation 
import SimulatedWFS, StaticSource, SLM, Microscope, ADCProcessor, Shutter +from ..openwfs.simulation import ( + SimulatedWFS, + StaticSource, + SLM, + Microscope, + ADCProcessor, + Shutter, +) from ..openwfs.utilities import set_pixel_size, tilt @@ -26,7 +38,9 @@ def assert_enhancement(slm, feedback, wfs_results, t_correct=None): ratio = after / before estimated_ratio = wfs_results.estimated_optimized_intensity / before print(f"expected: {estimated_ratio}, actual: {ratio}") - assert estimated_ratio * 0.5 <= ratio <= estimated_ratio * 2.0, f""" + assert ( + estimated_ratio * 0.5 <= ratio <= estimated_ratio * 2.0 + ), f""" The SSA algorithm did not enhance the focus as much as expected. Expected at least 0.5 * {estimated_ratio}, got {ratio}""" @@ -34,7 +48,10 @@ def assert_enhancement(slm, feedback, wfs_results, t_correct=None): # Check if we correctly measured the transmission matrix. # The correlation will be less for fewer segments, hence an (ad hoc) factor of 2/sqrt(n) t = wfs_results.t[:] - corr = np.abs(np.vdot(t_correct, t) / np.sqrt(np.vdot(t_correct, t_correct) * np.vdot(t, t))) + corr = np.abs( + np.vdot(t_correct, t) + / np.sqrt(np.vdot(t_correct, t_correct) * np.vdot(t, t)) + ) assert corr > 1.0 - 2.0 / np.sqrt(wfs_results.n) @@ -98,7 +115,13 @@ def test_ssa_noise(n_y, n_x): sim_no_noise = SimulatedWFS(aberrations=aberrations) slm = sim_no_noise.slm scale = np.max(sim_no_noise.read()) - sim = ADCProcessor(sim_no_noise, analog_max=scale * 200.0, digital_max=10000, shot_noise=True, generator=generator) + sim = ADCProcessor( + sim_no_noise, + analog_max=scale * 200.0, + digital_max=10000, + shot_noise=True, + generator=generator, + ) alg = StepwiseSequential(feedback=sim, slm=slm, n_x=n_x, n_y=n_y, phase_steps=10) result = alg.execute() print(result.fidelity_noise) @@ -129,8 +152,10 @@ def get_random_aberrations(): sim = SimulatedWFS(aberrations=get_random_aberrations(), slm=slm) # SSA - print(f'SSA run {r + 1}/{num_runs}') - alg_ssa = StepwiseSequential(feedback=sim, slm=sim.slm, n_x=13, n_y=13, phase_steps=6) + print(f"SSA run {r + 1}/{num_runs}") + alg_ssa = StepwiseSequential( + feedback=sim, slm=sim.slm, n_x=13, n_y=13, phase_steps=6 + ) wfs_result_ssa = alg_ssa.execute() sim.slm.set_phases(-np.angle(wfs_result_ssa.t)) shaped_intensities_ssa[r] = sim.read() @@ -140,7 +165,8 @@ def get_random_aberrations(): enhancement_ssa_std = shaped_intensities_ssa.std() / unshaped_intensities.mean() print( - f'SSA enhancement (squared signal): {enhancement_ssa:.2f}, std={enhancement_ssa_std:.2f}, with {wfs_result_ssa.n} modes') + f"SSA enhancement (squared signal): {enhancement_ssa:.2f}, std={enhancement_ssa_std:.2f}, with {wfs_result_ssa.n} modes" + ) assert enhancement_ssa > 100.0 @@ -153,9 +179,14 @@ def test_fourier(n_x): """ aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_angles_min=-n_x, - k_angles_max=n_x, - phase_steps=4) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=np.shape(aberrations), + k_angles_min=-n_x, + k_angles_max=n_x, + phase_steps=4, + ) results = alg.execute() assert_enhancement(sim.slm, sim, results, np.exp(1j * aberrations)) @@ -165,47 +196,69 @@ def test_fourier2(): slm_shape = (1000, 1000) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_angles_min=-5, - 
k_angles_max=5, - phase_steps=3) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=slm_shape, + k_angles_min=-5, + k_angles_max=5, + phase_steps=3, + ) controller = WFSController(alg) controller.wavefront = WFSController.State.SHAPED_WAVEFRONT scaled_aberration = zoom(aberrations, np.array(slm_shape) / aberrations.shape) assert_enhancement(sim.slm, sim, controller._result, np.exp(1j * scaled_aberration)) -@pytest.mark.skip(reason="This test is is not passing yet and needs further inspection to see if the test itself is " - "correct.") +@pytest.mark.skip( + reason="This test is is not passing yet and needs further inspection to see if the test itself is " + "correct." +) def test_fourier3(): """Test the Fourier dual reference algorithm using WFSController.""" slm_shape = (32, 32) aberrations = np.random.uniform(0.0, 2 * np.pi, slm_shape) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_angles_min=-32, - k_angles_max=32, - phase_steps=3) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=slm_shape, + k_angles_min=-32, + k_angles_max=32, + phase_steps=3, + ) controller = WFSController(alg) controller.wavefront = WFSController.State.SHAPED_WAVEFRONT scaled_aberration = zoom(aberrations, np.array(slm_shape) / aberrations.shape) assert_enhancement(sim.slm, sim, controller._result, np.exp(1j * scaled_aberration)) -@pytest.mark.parametrize("k_radius, g", [[2.5, (1.0, 0.0)], [2.5, (0.0, 2.0)]], ) +@pytest.mark.parametrize( + "k_radius, g", + [[2.5, (1.0, 0.0)], [2.5, (0.0, 2.0)]], +) def test_fourier_circle(k_radius, g): """ Test Fourier dual reference algorithm with a circular k-space, with a tilt 'aberration'. """ aberrations = tilt(shape=(100, 100), extent=(2, 2), g=g, phase_offset=0.5) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReferenceCircle(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=k_radius, - phase_steps=4) + alg = FourierDualReferenceCircle( + feedback=sim, + slm=sim.slm, + slm_shape=np.shape(aberrations), + k_radius=k_radius, + phase_steps=4, + ) results = alg.execute() assert_enhancement(sim.slm, sim, results, np.exp(1j * aberrations)) def test_fourier_microscope(): aberration_phase = skimage.data.camera() * ((2 * np.pi) / 255.0) + np.pi - aberration = StaticSource(aberration_phase, pixel_size=2.0 / np.array(aberration_phase.shape)) + aberration = StaticSource( + aberration_phase, pixel_size=2.0 / np.array(aberration_phase.shape) + ) img = np.zeros((1000, 1000), dtype=np.int16) signal_location = (250, 250) img[signal_location] = 100 @@ -213,13 +266,24 @@ def test_fourier_microscope(): src = StaticSource(img, 400 * u.nm) slm = SLM(shape=(1000, 1000)) - sim = Microscope(source=src, incident_field=slm.field, magnification=1, numerical_aperture=1, - aberrations=aberration, - wavelength=800 * u.nm) + sim = Microscope( + source=src, + incident_field=slm.field, + magnification=1, + numerical_aperture=1, + aberrations=aberration, + wavelength=800 * u.nm, + ) cam = sim.get_camera(analog_max=100) roi_detector = SingleRoi(cam, pos=(250, 250)) # Only measure that specific point - alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_angles_min=-1, k_angles_max=1, - phase_steps=3) + alg = FourierDualReference( + feedback=roi_detector, + slm=slm, + slm_shape=slm_shape, + k_angles_min=-1, + k_angles_max=1, + phase_steps=3, + ) controller = WFSController(alg) controller.wavefront = 
WFSController.State.FLAT_WAVEFRONT before = roi_detector.read() @@ -227,8 +291,12 @@ def test_fourier_microscope(): after = roi_detector.read() # imshow(controller._optimized_wavefront) print(after / before) - scaled_aberration = zoom(aberration_phase, np.array(slm_shape) / aberration_phase.shape) - assert_enhancement(slm, roi_detector, controller._result, np.exp(1j * scaled_aberration)) + scaled_aberration = zoom( + aberration_phase, np.array(slm_shape) / aberration_phase.shape + ) + assert_enhancement( + slm, roi_detector, controller._result, np.exp(1j * scaled_aberration) + ) def test_fourier_correction_field(): @@ -237,13 +305,20 @@ def test_fourier_correction_field(): """ aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_angles_min=-2, - k_angles_max=2, - phase_steps=3) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=np.shape(aberrations), + k_angles_min=-2, + k_angles_max=2, + phase_steps=3, + ) t = alg.execute().t t_correct = np.exp(1j * aberrations) - correlation = np.vdot(t, t_correct) / np.sqrt(np.vdot(t, t) * np.vdot(t_correct, t_correct)) + correlation = np.vdot(t, t_correct) / np.sqrt( + np.vdot(t, t) * np.vdot(t_correct, t_correct) + ) # TODO: integrate with other test cases, duplication assert abs(correlation) > 0.75 @@ -256,9 +331,14 @@ def test_phase_shift_correction(): """ aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_angles_min=-1, - k_angles_max=1, - phase_steps=3) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=np.shape(aberrations), + k_angles_min=-1, + k_angles_max=1, + phase_steps=3, + ) t = alg.execute().t # compute the phase pattern to optimize the intensity in target 0 @@ -275,7 +355,9 @@ def test_phase_shift_correction(): signal = sim.read() signals.append(signal) - assert np.std(signals) < 0.0001 * before, f"""The simulated response of the Fourier algorithm is sensitive to a + assert ( + np.std(signals) < 0.0001 * before + ), f"""The simulated response of the Fourier algorithm is sensitive to a flat phase-shift. This is incorrect behaviour""" @@ -290,8 +372,14 @@ def test_flat_wf_response_fourier(): aberrations = np.zeros(shape=(512, 512)) sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_angles_min=-1, - k_angles_max=1, phase_steps=3) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=np.shape(aberrations), + k_angles_min=-1, + k_angles_max=1, + phase_steps=3, + ) t = alg.execute().t @@ -314,7 +402,9 @@ def test_flat_wf_response_ssa(): # Assert that the standard deviation of the optimized wavefront is below the threshold, # indicating that it is effectively flat - assert np.std(optimised_wf) < 0.001, f"Response flat wavefront not flat, std: {np.std(optimised_wf)}" + assert ( + np.std(optimised_wf) < 0.001 + ), f"Response flat wavefront not flat, std: {np.std(optimised_wf)}" def test_multidimensional_feedback_ssa(): @@ -336,7 +426,9 @@ def test_multidimensional_feedback_ssa(): after = sim.read() enhancement = after / before - assert enhancement[target] >= 3.0, f"""The SSA algorithm did not enhance focus as much as expected. 
+ assert ( + enhancement[target] >= 3.0 + ), f"""The SSA algorithm did not enhance focus as much as expected. Expected at least 3.0, got {enhancement}""" @@ -345,7 +437,9 @@ def test_multidimensional_feedback_fourier(): sim = SimulatedWFS(aberrations=aberrations) # input the camera as a feedback object, such that it is multidimensional - alg = FourierDualReference(feedback=sim, slm=sim.slm, k_angles_min=-1, k_angles_max=1, phase_steps=3) + alg = FourierDualReference( + feedback=sim, slm=sim.slm, k_angles_min=-1, k_angles_max=1, phase_steps=3 + ) t = alg.execute().t # compute the phase pattern to optimize the intensity in target 0 @@ -359,7 +453,9 @@ def test_multidimensional_feedback_fourier(): after = sim.read() enhancement = after / before - assert enhancement[2, 1] >= 3.0, f"""The algorithm did not enhance the focus as much as expected. + assert ( + enhancement[2, 1] >= 3.0 + ), f"""The algorithm did not enhance the focus as much as expected. Expected at least 3.0, got {enhancement}""" @@ -381,10 +477,17 @@ def test_ssa_fidelity(gaussian_noise_std): shutter = Shutter(slm.field) # Simulate a WFS microscope looking at the specimen - sim = Microscope(source=specimen, incident_field=shutter, aberrations=aberrations, wavelength=800 * u.nm) + sim = Microscope( + source=specimen, + incident_field=shutter, + aberrations=aberrations, + wavelength=800 * u.nm, + ) # Simulate a camera device with gaussian noise and shot noise - cam = sim.get_camera(analog_max=1e4, shot_noise=False, gaussian_noise_std=gaussian_noise_std) + cam = sim.get_camera( + analog_max=1e4, shot_noise=False, gaussian_noise_std=gaussian_noise_std + ) # Define feedback as circular region of interest in the center of the frame roi_detector = SingleRoi(cam, radius=1) @@ -393,16 +496,24 @@ def test_ssa_fidelity(gaussian_noise_std): # Use the stepwise sequential (SSA) WFS algorithm n_x = 10 n_y = 10 - alg = StepwiseSequential(feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=8) + alg = StepwiseSequential( + feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=8 + ) # Define a region of interest to determine average speckle intensity roi_background = SingleRoi(cam, radius=50) # Run WFS troubleshooter and output a report to the console - trouble = troubleshoot(algorithm=alg, background_feedback=roi_background, - frame_source=cam, shutter=shutter) + trouble = troubleshoot( + algorithm=alg, + background_feedback=roi_background, + frame_source=cam, + shutter=shutter, + ) - assert np.isclose(trouble.measured_enhancement, trouble.expected_enhancement, rtol=0.2) + assert np.isclose( + trouble.measured_enhancement, trouble.expected_enhancement, rtol=0.2 + ) def test_ssa_aberration_reconstruction(): @@ -415,7 +526,11 @@ def test_ssa_aberration_reconstruction(): # Create aberrations x = np.linspace(-1, 1, n_x).reshape((1, -1)) y = np.linspace(-1, 1, n_y).reshape((-1, 1)) - aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi) + aberrations = ( + np.sin(0.8 * np.pi * x) + * np.cos(1.3 * np.pi * y) + * (0.8 * np.pi + 0.4 * x + 0.4 * y) + ) % (2 * np.pi) aberrations[0:1, :] = 0 aberrations[:, 0:1] = 0 @@ -428,20 +543,25 @@ def test_ssa_aberration_reconstruction(): if do_debug: import matplotlib.pyplot as plt + plt.figure() - plt.imshow(np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('Aberrations') + plt.imshow( + np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap="hsv" + ) + plt.title("Aberrations") plt.figure() - 
plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('t') + plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") + plt.title("t") plt.colorbar() plt.show() assert np.abs(field_correlation(np.exp(1j * aberrations), result.t)) > 0.99 -@pytest.mark.parametrize("construct_basis", (half_plane_wave_basis, half_hadamard_basis)) +@pytest.mark.parametrize( + "construct_basis", (half_plane_wave_basis, half_hadamard_basis) +) def test_custom_blind_dual_reference_ortho_split(construct_basis: callable): """Test custom blind dual reference with an orthonormal phase-only basis.""" do_debug = False @@ -458,11 +578,12 @@ def test_custom_blind_dual_reference_ortho_split(construct_basis: callable): if do_debug: # Plot the modes import matplotlib.pyplot as plt + plt.figure(figsize=(12, 7)) for m in range(M): plt.subplot(N2, N1, m + 1) plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi) - plt.title(f'm={m}') + plt.title(f"m={m}") plt.xticks([]) plt.yticks([]) plt.pause(0.1) @@ -470,26 +591,38 @@ def test_custom_blind_dual_reference_ortho_split(construct_basis: callable): # Create aberrations x = np.linspace(-1, 1, 1 * N1).reshape((1, -1)) y = np.linspace(-1, 1, 1 * N1).reshape((-1, 1)) - aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi) + aberrations = ( + np.sin(0.8 * np.pi * x) + * np.cos(1.3 * np.pi * y) + * (0.8 * np.pi + 0.4 * x + 0.4 * y) + ) % (2 * np.pi) aberrations[0:2, :] = 0 aberrations[:, 0:2] = 0 sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) - alg = CustomIterativeDualReference(feedback=sim, slm=sim.slm, slm_shape=aberrations.shape, - phases=(phases_set, np.flip(phases_set, axis=1)), set1_mask=mask, phase_steps=4, - iterations=4) + alg = CustomIterativeDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=aberrations.shape, + phases=(phases_set, np.flip(phases_set, axis=1)), + set1_mask=mask, + phase_steps=4, + iterations=4, + ) result = alg.execute() if do_debug: plt.figure() - plt.imshow(np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('Aberrations') + plt.imshow( + np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap="hsv" + ) + plt.title("Aberrations") plt.figure() - plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('t') + plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") + plt.title("t") plt.colorbar() plt.show() @@ -506,7 +639,9 @@ def test_custom_blind_dual_reference_non_ortho(): N1 = 6 N2 = 3 M = N1 * N2 - mode_set_half = (1 / M) * (1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M))) + mode_set_half = (1 / M) * ( + 1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M)) + ) mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1) phases_set = np.angle(mode_set) mask = np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1) @@ -514,40 +649,53 @@ def test_custom_blind_dual_reference_non_ortho(): if do_debug: # Plot the modes import matplotlib.pyplot as plt + plt.figure(figsize=(12, 7)) for m in range(M): plt.subplot(N2, N1, m + 1) plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi) - plt.title(f'm={m}') + plt.title(f"m={m}") plt.xticks([]) plt.yticks([]) plt.pause(0.01) - plt.suptitle('Phase of basis functions for one half') + plt.suptitle("Phase of basis functions for one half") # Create aberrations x = np.linspace(-1, 1, 1 * N1).reshape((1, -1)) y = np.linspace(-1, 1, 1 * 
N1).reshape((-1, 1)) - aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi) + aberrations = ( + np.sin(0.8 * np.pi * x) + * np.cos(1.3 * np.pi * y) + * (0.8 * np.pi + 0.4 * x + 0.4 * y) + ) % (2 * np.pi) aberrations[0:1, :] = 0 aberrations[:, 0:2] = 0 sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) - alg = CustomIterativeDualReference(feedback=sim, slm=sim.slm, slm_shape=aberrations.shape, - phases=(phases_set, np.flip(phases_set, axis=1)), set1_mask=mask, phase_steps=4, - iterations=4) + alg = CustomIterativeDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=aberrations.shape, + phases=(phases_set, np.flip(phases_set, axis=1)), + set1_mask=mask, + phase_steps=4, + iterations=4, + ) result = alg.execute() if do_debug: plt.figure() - plt.imshow(np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('Aberrations') + plt.imshow( + np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap="hsv" + ) + plt.title("Aberrations") plt.colorbar() plt.figure() - plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('t') + plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") + plt.title("t") plt.colorbar() plt.show() From c1283243fc87641e375f93ae0566b2ea61ce6421 Mon Sep 17 00:00:00 2001 From: Jeroen Doornbos Date: Sun, 29 Sep 2024 15:44:45 +0200 Subject: [PATCH 03/12] adapted styleguide --- STYLEGUIDE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/STYLEGUIDE.md b/STYLEGUIDE.md index 0143572..5cace50 100644 --- a/STYLEGUIDE.md +++ b/STYLEGUIDE.md @@ -6,7 +6,7 @@ # General -- PyCharm autoformatting should be enabled to ensure correct formatting. +- the package `black` is used to ensure correct formatting. Install with `pip install black` and run in the terminal using `black .` when located at the root of the repository. # Tests From 863456f5afaf7ce9862f32aa6b1ae38315c801ef Mon Sep 17 00:00:00 2001 From: Jeroen Doornbos Date: Sun, 29 Sep 2024 15:45:24 +0200 Subject: [PATCH 04/12] Typo --- STYLEGUIDE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/STYLEGUIDE.md b/STYLEGUIDE.md index 5cace50..0f814e0 100644 --- a/STYLEGUIDE.md +++ b/STYLEGUIDE.md @@ -6,7 +6,7 @@ # General -- the package `black` is used to ensure correct formatting. Install with `pip install black` and run in the terminal using `black .` when located at the root of the repository. +- The package `black` is used to ensure correct formatting. Install with `pip install black` and run in the terminal using `black .` when located at the root of the repository. 
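+  Black can also be run in check-only mode (`black --check .`), which reports files that would be reformatted without changing them; this is useful for verifying formatting locally before pushing.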
# Tests From b742a758fe2cf82246ddf14262270e71d2756b01 Mon Sep 17 00:00:00 2001 From: Jeroen Doornbos Date: Sun, 29 Sep 2024 19:37:32 +0200 Subject: [PATCH 05/12] wording --- docs/source/readme.rst | 8 +++++--- docs/source/references.bib | 7 +++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 87c7d8e..2939e76 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -26,7 +26,7 @@ Wavefront shaping (WFS) is a technique for controlling the propagation of light It stands out that an important driving force in WFS is the development of new algorithms, for example to account for sample movement :cite:`valzania2023online`, to be optimally resilient to noise :cite:`mastiani2021noise`, or to use digital twin models to compute the required correction patterns :cite:`salter2014exploring,ploschner2015seeing,Thendiyammal2020,cox2023model`. Much progress has been made towards developing fast and noise-resilient algorithms, or algorithms designed for specific towards the methodology of wavefront shaping, such as using algorithms based on Hadamard patterns, or Fourier-based approaches :cite:`Mastiani2022`. Fast techniques that enable wavefront shaping in dynamic samples :cite:`Liu2017,Tzang2019`, and many potential applications have been developed and prototyped, including endoscopy :cite:`ploschner2015seeing`, optical trapping :cite:`Cizmar2010` and deep-tissue imaging :cite:`Streich2021`. -With the development of these advanced algorithms, however, the complexity of WFS software is gradually becoming a bottleneck for further advancements in the field, as well as for end-user adoption. Code for controlling wavefront shaping tends to be complex and setup-specific, and developing this code typically requires detailed technical knowledge and low-level programming. Moreover, since many labs use their own in-house programs to control the experiments, sharing and re-using code between different research groups is troublesome. +With the development of these advanced algorithms, however, the complexity of WFS software is steadily increasing as the field matures, which hinders cooperation as well as end-user adoption. Code for controlling wavefront shaping tends to be complex and setup-specific, and developing this code typically requires detailed technical knowledge and low-level programming. Moreover, since many labs use their own in-house programs to control the experiments, sharing and re-using code between different research groups is troublesome. What is OpenWFS? ---------------------- @@ -40,11 +40,13 @@ OpenWFS is a Python package for performing and for simulating wavefront shaping * **GenICam cameras**. The :class:`~.devices.Camera` object uses the `harvesters` backend :cite:`harvesters` to access any camera supporting the GenICam standard :cite:`genicam`. * **Automatic synchronization**. OpenWFS provides tools for automatic synchronization of actuators (e. g. an SLM) and detectors (e. g. a camera). The automatic synchronization makes it trivial to perform pipelined measurements that avoid the delay normally caused by the latency of the video card and SLM. -* **Wavefront shaping algorithms**. A (growing) collection of wavefront shaping algorithms. OpenWFS abstracts the hardware control, synchronization, and signal processing so that the user can focus on the algorithm itself. As a result, most algorithms can be implemented in just a few lines of code without the need for low-level or hardware-specific programming. 
+* **Wavefront shaping algorithms**. A (growing) collection of wavefront shaping algorithms. OpenWFS abstracts the hardware control, synchronization, and signal processing so that the user can focus on the algorithm itself. As a result, most algorithms can be implemented cleanly without hardware-specific programming. * **Simulation**. OpenWFS provides an extensive framework for testing and simulating wavefront shaping algorithms, including the effect of measurement noise, stage drift, and user-defined aberrations. This allows for rapid prototyping and testing of new algorithms, without the need for physical hardware. -* **Platform for exchange and joint collaboration**. OpenWFS can be used as a platform for sharing and exchanging wavefront shaping algorithms. The package is designed to be modular and easy to expand, and it is our hope that the community will contribute to the package by adding new algorithms, hardware control modules, and simulation tools. +* **Platform for exchange and joint collaboration**. OpenWFS can be used as a platform for sharing and exchanging wavefront shaping algorithms. The package is designed to be modular and easy to expand, and it is our hope that the community will contribute to the package by adding new algorithms, hardware control modules, and simulation tools. + +* **Platform for simplifying use of wavefront shaping**. OpenWFS is compatible to the recently developed PyDevice :cite:`PyDevice`, and can therefore be controlled from Micro-Manager :cite:`MMoverview`, a commonly used microscopy control platform. * **Automated troubleshooting**. OpenWFS provides tools for automated troubleshooting of wavefront shaping experiments. This includes tools for measuring the performance of wavefront shaping algorithms, and for identifying common problems such as incorrect SLM calibration, drift, measurement noise, and other experimental imperfections. diff --git a/docs/source/references.bib b/docs/source/references.bib index fe4be20..a441e31 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -452,3 +452,10 @@ @book{kubby2019 editor = {Kubby, Joel and Gigan, Sylvain and Cui, Meng}, year = {2019}, collection = {Advances in Microscopy and Microanalysis} } + +@misc{PyDevice, + author = {Ivo Vellekoop and Jeroen Doornbos}, + title = {Micro-Manager PyDevice}, + url = {https://micro-manager.org/PyDevice}, +} + From b2ffb8c2a6541cc841a656881a77a8e6288281d7 Mon Sep 17 00:00:00 2001 From: Jeroen Doornbos Date: Sun, 29 Sep 2024 22:20:06 +0200 Subject: [PATCH 06/12] more extensive review --- docs/source/readme.rst | 6 +-- docs/source/references.bib | 89 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 4 deletions(-) diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 2939e76..9305ff8 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -17,14 +17,12 @@ OpenWFS [![PyTest](https://github.com/IvoVellekoop/openwfs/actions/workflows/pytest.yml/badge.svg)](https://github.com/IvoVellekoop/openwfs/actions/workflows/pytest.yml) [![Black](https://github.com/IvoVellekoop/openwfs/actions/workflows/black.yml/badge.svg)](https://github.com/IvoVellekoop/openwfs/actions/workflows/black.yml) -What is wavefront shaping? - What is wavefront shaping? -------------------------------- -Wavefront shaping (WFS) is a technique for controlling the propagation of light in arbitrarily complex structures, including strongly scattering materials :cite:`kubby2019`. 
In WFS, a spatial light modulator (SLM) is used to shape the phase and/or amplitude of the incident light. With a properly constructed wavefront, light can be made to focus through :cite:`Vellekoop2007`, or inside :cite:`vellekoop2008demixing` scattering materials; or light can be shaped to have other desired properties, such as optimal sensitivity for specific measurements :cite:`bouchet2021maximum`, specialized point-spread functions :cite:`boniface2017transmission` or for functions like optical trapping :cite:`vcivzmar2010situ`. +Wavefront shaping (WFS) is a technique for controlling the propagation of light in arbitrarily complex structures, including strongly scattering materials :cite:`kubby2019`. In WFS, a spatial light modulator (SLM) is used to shape the phase and/or amplitude of the incident light. With a properly constructed wavefront, light can be made to focus through :cite:`Vellekoop2007`, or inside :cite:`vellekoop2008demixing` scattering materials; or light can be shaped to have other desired properties, such as optimal sensitivity for specific measurements :cite:`bouchet2021maximum`, specialized point-spread functions :cite:`boniface2017transmission`, spectral filtering :cite:`Park2012`,, or for functions like optical trapping :cite:`vcivzmar2010situ`. -It stands out that an important driving force in WFS is the development of new algorithms, for example to account for sample movement :cite:`valzania2023online`, to be optimally resilient to noise :cite:`mastiani2021noise`, or to use digital twin models to compute the required correction patterns :cite:`salter2014exploring,ploschner2015seeing,Thendiyammal2020,cox2023model`. Much progress has been made towards developing fast and noise-resilient algorithms, or algorithms designed for specific towards the methodology of wavefront shaping, such as using algorithms based on Hadamard patterns, or Fourier-based approaches :cite:`Mastiani2022`. Fast techniques that enable wavefront shaping in dynamic samples :cite:`Liu2017,Tzang2019`, and many potential applications have been developed and prototyped, including endoscopy :cite:`ploschner2015seeing`, optical trapping :cite:`Cizmar2010` and deep-tissue imaging :cite:`Streich2021`. +It stands out that an important driving force in WFS is the development of new algorithms, for example to account for sample movement :cite:`valzania2023online`, experimental conditions :cite:`Anderson2016`, to be optimally resilient to noise :cite:`mastiani2021noise`, or to use digital twin models to compute the required correction patterns :cite:`salter2014exploring,ploschner2015seeing,Thendiyammal2020,cox2023model`. Much progress has been made towards developing fast and noise-resilient algorithms, or algorithms designed for specific towards the methodology of wavefront shaping, such as using algorithms based on Hadamard patterns, or Fourier-based approaches :cite:`Mastiani2022`. Fast techniques that enable wavefront shaping in dynamic samples :cite:`Liu2017,Tzang2019`, and many potential applications have been developed and prototyped, including endoscopy :cite:`ploschner2015seeing`, optical trapping :cite:`Cizmar2010`, Raman scattering, :cite:`Thompson2016`, and deep-tissue imaging :cite:`Streich2021`. Applications extend beyond that of microscope imaging such as optimizing photoelectrochemical absorption :cite:`Liew2016` and tuning random lasers :cite:`Bachelard2014`. 
With the development of these advanced algorithms, however, the complexity of WFS software is steadily increasing as the field matures, which hinders cooperation as well as end-user adoption. Code for controlling wavefront shaping tends to be complex and setup-specific, and developing this code typically requires detailed technical knowledge and low-level programming. Moreover, since many labs use their own in-house programs to control the experiments, sharing and re-using code between different research groups is troublesome. diff --git a/docs/source/references.bib b/docs/source/references.bib index a441e31..7398070 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -157,6 +157,24 @@ @article{astropy primaryClass = {astro-ph.IM}, } +@article{Thompson2016, + abstract = {Spontaneous Raman scattering is a powerful tool for chemical sensing and imaging but suffers from a weak signal. In this Letter, we present an application of adaptive optics to enhance the Raman scattering signal detected through a turbid, optically thick material. This technique utilizes recent advances in wavefront shaping techniques for focusing light through a turbid media and applies them to chemical detection to achieve a signal enhancement with little sacrifice to the overall simplicity of the experimental setup. With this technique, we demonstrate an enhancement in the Raman signal from titanium dioxide particles through a highly scattering material. This technique may pave the way to label-free tracking using the optical memory effect.}, + author = {Jonathan V. Thompson and Graham A. Throckmorton and Brett H. Hokr and Vladislav V. Yakovlev}, + doi = {10.1364/OL.41.001769}, + issn = {1539-4794}, + issue = {8}, + journal = {Optics letters}, + keywords = {Graham A Throckmorton,Jonathan V Thompson,MEDLINE,NCBI,NIH,NLM,National Center for Biotechnology Information,National Institutes of Health,National Library of Medicine,Non-P.H.S.,PubMed Abstract,Radiation*,Raman*,Research Support,Scattering,Spectrum Analysis,U.S. Gov't,Vladislav V Yakovlev,doi:10.1364/OL.41.001769,pmid:27082341}, + month = {4}, + pages = {1769}, + pmid = {27082341}, + publisher = {Opt Lett}, + title = {Wavefront shaping enhanced Raman scattering in a turbid medium}, + volume = {41}, + url = {https://pubmed.ncbi.nlm.nih.gov/27082341/}, + year = {2016}, +} + @article{vellekoop2008demixing, title = {Demixing light paths inside disordered metamaterials}, @@ -364,6 +382,41 @@ @article{Streich2021 year = {2021}, } +@article{Liew2016, + abstract = {A fundamental issue that limits the efficiency of many photoelectrochemical systems is that the photon absorption length is typically much longer than the electron diffusion length. Various photon management schemes have been developed to enhance light absorption; one simple approach is to use randomly scattering media to enable broadband and wide-angle enhancement. However, such systems are often opaque, making it difficult to probe photoinduced processes. Here we use wave interference effects to modify the spatial distribution of light inside a highly scattering dye-sensitized solar cell to control photon absorption in a space-dependent manner. By shaping the incident wavefront of a laser beam, we enhance or suppress photocurrent by increasing or decreasing light concentration on the front side of the mesoporous photoanode where the collection efficiency of photoelectrons is maximal. 
Enhanced light absorption is achieved by reducing reflection through the open boundary of the photoanode via destructive interference, leading to a factor of 2 increase in photocurrent. This approach opens the door to probing and manipulating photoelectrochemical processes in specific regions inside nominally opaque media.}, + author = {Seng Fatt Liew and Sébastien M. Popoff and Stafford W. Sheehan and Arthur Goetschy and Charles A. Schmuttenmaer and A. Douglas Stone and Hui Cao}, + doi = {10.1021/ACSPHOTONICS.5B00642}, + issn = {2330-4022}, + issue = {3}, + journal = {ACS Photonics}, + keywords = {dye-sensitized solar cells,multimode interference,multiple scattering,photoelectrochemical,wavefront shaping}, + month = {3}, + pages = {449-455}, + publisher = {American Chemical Society}, + title = {Coherent Control of Photocurrent in a Strongly Scattering Photoelectrochemical System}, + volume = {3}, + url = {https://technion-staging.elsevierpure.com/en/publications/coherent-control-of-photocurrent-in-a-strongly-scattering-photoel}, + year = {2016}, +} + + +@article{Anderson2016, + abstract = {Previously we considered the effect of experimental parameters on optimized transmission through opaque media using spatial light modulator (SLM)-based wavefront shaping. In this study we consider the opposite geometry, in which we optimize reflection from an opaque surface such that the backscattered light is focused onto a spot on an imaging detector. By systematically varying different experimental parameters (genetic algorithm iterations, bin size, SLM active area, target area, spot size, and sample angle with respect to the optical axis) and optimizing the reflected light we determine how each parameter affects the intensity enhancement. We find that the effects of the experimental parameters on the enhancement are similar to those measured for a transmissive geometry, but with the exact functional forms changed due to the different geometry and the use of a genetic algorithm instead of an iterative algorithm. Additionally, we find preliminary evidence of greater enhancements than predicted by random matrix theory, suggesting a possibly new physical mechanism to be investigated in future work.}, + author = {Benjamin R. Anderson and Ray Gunawidjaja and Hergen Eilers}, + doi = {10.1103/PHYSREVA.93.013813/FIGURES/12/MEDIUM}, + issn = {24699934}, + issue = {1}, + journal = {Physical Review A}, + month = {1}, + pages = {013813}, + publisher = {American Physical Society}, + title = {Effect of experimental parameters on optimal reflection of light from opaque media}, + volume = {93}, + url = {https://journals.aps.org/pra/abstract/10.1103/PhysRevA.93.013813}, + year = {2016}, +} + + @misc{MMoverview, author = {Mark Tsuchida and Sam Griffin}, title = {Micro-Manager Project Overview}, @@ -375,6 +428,42 @@ @misc{openwfsdocumentation url = {https://openwfs.readthedocs.io/en/latest/}, } +@article{Bachelard2014, + abstract = {A laser is not necessarily a sophisticated device: pumping an amplifying medium randomly filled with scatterers makes a perfectly viable â ̃ random laserâ ™. The absence of mirrors greatly simplifies laser design, but control over the emission wavelength and directionality is lost, seriously hindering prospects for this otherwise simple laser. Recently, we proposed an approach to tame random lasers, inspired by coherent light control in complex media. 
Here, we implement this method in an optofluidic random laser where modes are spatially extended and overlap, making individual mode selection impossible, a priori. We show experimentally that control over laser emission can be regained even in this extreme case. By actively shaping the optical pump within the random laser, single-mode operation at any selected wavelength is achieved with spectral selectivity down to 0.06 nm and more than 10 dB side-lobe rejection. This method paves the way towards versatile tunable and controlled random lasers as well as the taming of other laser sources. © 2014 Macmillan Publishers Limited. All rights reserved.}, + author = {Nicolas Bachelard and Sylvain Gigan and Xavier Noblin and Patrick Sebbah}, + doi = {10.1038/nphys2939}, + issn = {17452481}, + issue = {6}, + journal = {Nature Physics}, + keywords = {Optics,Physics}, + pages = {426-431}, + publisher = {Nature Publishing Group}, + title = {Adaptive pumping for spectral control of random lasers}, + volume = {10}, + url = {https://ui.adsabs.harvard.edu/abs/2014NatPh..10..426B/abstract}, + year = {2014}, +} + + +@article{Park2012, + abstract = {We demonstrate controlled wavelength-dependent light focusing through turbid media using wavefront shaping. Due to the dispersion caused by multiple light scattering, light propagation through turbid media can be independently controlled between different wavelengths. Foci with various wavelengths can be generated by applying different optimized wavefronts to a highly scattering layer. Given the linearity of the transmission matrix, multiple foci with different wavelengths can also be simultaneously constructed.}, + author = {Jung-Hoon Park and ChungHyun Park and YongKeun Park and Hyunseung Yu and Yong-Hoon Cho}, + doi = {10.1364/OL.37.003261}, + issn = {1539-4794}, + issue = {15}, + journal = {Optics Letters, Vol. 37, Issue 15, pp. 3261-3263}, + keywords = {Light propagation,Multiple scattering,Scattering media,Second harmonic generation,Spatial light modulators,Turbid media}, + month = {8}, + pages = {3261-3263}, + pmid = {22859152}, + publisher = {Optica Publishing Group}, + title = {Active spectral filtering through turbid media}, + volume = {37}, + url = {https://opg.optica.org/viewmedia.cfm?uri=ol-37-15-3261&seq=0&html=true https://opg.optica.org/abstract.cfm?uri=ol-37-15-3261 https://opg.optica.org/ol/abstract.cfm?uri=ol-37-15-3261}, + year = {2012}, +} + + @misc{pydevice, title = {{PyDevice} {GitHub} repository}, url = {https://www.github.com/IvoVellekoop/pydevice}, From e2352104ed49b9e7f38dcfec164d8ec85903ec51 Mon Sep 17 00:00:00 2001 From: Jeroen Doornbos Date: Sun, 29 Sep 2024 23:03:31 +0200 Subject: [PATCH 07/12] further writing --- docs/source/development.rst | 4 +++- docs/source/readme.rst | 6 +++--- docs/source/references.bib | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/docs/source/development.rst b/docs/source/development.rst index f2b1a1c..7b57fcb 100644 --- a/docs/source/development.rst +++ b/docs/source/development.rst @@ -14,7 +14,9 @@ To download the source code, including tests and examples, clone the repository poetry install --with dev --with docs poetry run pytest -The examples are located in the ``examples`` directory. Note that a lot of functionality is also demonstrated in the automatic tests located in the ``tests`` directory. 
As an alternative to downloading the source code, the samples can also be copied directly from the example gallery on the documentation website :cite:`readthedocsOpenWFS`. +The examples are located in the ``examples`` directory. Note that a lot of functionality is also demonstrated in the automatic tests located in the ``tests`` directory. As an alternative to downloading the source code, the samples can also be copied directly from the example gallery on the documentation website :cite:`readthedocsOpenWFS`. + +Note that when adding hardware devices, many of the components rely on third-party, in some cases proprietary, drivers. To use NI DAQ components, the ``nidaqmx`` package needs to be installed, and for OpenCV and GenICam their respective drivers need to be installed. The specific requirements are always listed in the documentation of the functions and classes that require these packages. Building the documentation -------------------------------------------------- diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 9305ff8..65cf83d 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -24,7 +24,7 @@ Wavefront shaping (WFS) is a technique for controlling the propagation of light It stands out that an important driving force in WFS is the development of new algorithms, for example to account for sample movement :cite:`valzania2023online`, experimental conditions :cite:`Anderson2016`, to be optimally resilient to noise :cite:`mastiani2021noise`, or to use digital twin models to compute the required correction patterns :cite:`salter2014exploring,ploschner2015seeing,Thendiyammal2020,cox2023model`. Much progress has been made towards developing fast and noise-resilient algorithms, or algorithms designed for specific towards the methodology of wavefront shaping, such as using algorithms based on Hadamard patterns, or Fourier-based approaches :cite:`Mastiani2022`. Fast techniques that enable wavefront shaping in dynamic samples :cite:`Liu2017,Tzang2019`, and many potential applications have been developed and prototyped, including endoscopy :cite:`ploschner2015seeing`, optical trapping :cite:`Cizmar2010`, Raman scattering, :cite:`Thompson2016`, and deep-tissue imaging :cite:`Streich2021`. Applications extend beyond that of microscope imaging such as optimizing photoelectrochemical absorption :cite:`Liew2016` and tuning random lasers :cite:`Bachelard2014`. -With the development of these advanced algorithms, however, the complexity of WFS software is steadily increasing as the field matures, which hinders cooperation as well as end-user adoption. Code for controlling wavefront shaping tends to be complex and setup-specific, and developing this code typically requires detailed technical knowledge and low-level programming. Moreover, since many labs use their own in-house programs to control the experiments, sharing and re-using code between different research groups is troublesome. +With the development of these advanced algorithms, however, the complexity of WFS software is steadily increasing as the field matures, which hinders cooperation as well as end-user adoption. Code for controlling wavefront shaping tends to be complex and setup-specific, and developing this code typically requires detailed technical knowledge and low-level programming. A recent C++-based contribution :cite:`Anderson2024` highlights the growing need for software tools that support both the use and the further development of wavefront shaping.
Moreover, since many labs use their own in-house programs to control the experiments, sharing and re-using code between different research groups is troublesome. What is OpenWFS? ---------------------- @@ -42,9 +42,9 @@ OpenWFS is a Python package for performing and for simulating wavefront shaping * **Simulation**. OpenWFS provides an extensive framework for testing and simulating wavefront shaping algorithms, including the effect of measurement noise, stage drift, and user-defined aberrations. This allows for rapid prototyping and testing of new algorithms, without the need for physical hardware. -* **Platform for exchange and joint collaboration**. OpenWFS can be used as a platform for sharing and exchanging wavefront shaping algorithms. The package is designed to be modular and easy to expand, and it is our hope that the community will contribute to the package by adding new algorithms, hardware control modules, and simulation tools. +* **Platform for exchange and joint collaboration**. OpenWFS can be used as a platform for sharing and exchanging wavefront shaping algorithms. The package is designed to be modular and easy to expand, and it is our hope that the community will contribute to the package by adding new algorithms, hardware control modules, and simulation tools. Python was specifically chosen for this purpose for its active community, high level of abstraction and the ease of sharing tools. Further expansion of the supported hardware is of high priority, especially wrapping c-based software support with tools like ctypes and the Micro-Manager based device adapters. -* **Platform for simplifying use of wavefront shaping**. OpenWFS is compatible to the recently developed PyDevice :cite:`PyDevice`, and can therefore be controlled from Micro-Manager :cite:`MMoverview`, a commonly used microscopy control platform. +* **Platform for simplifying use of wavefront shaping**. OpenWFS is compatible with the recently developed PyDevice :cite:`PyDevice`, and can therefore be controlled from Micro-Manager :cite:`MMoverview`, a commonly used microscopy control platform. * **Automated troubleshooting**. OpenWFS provides tools for automated troubleshooting of wavefront shaping experiments. This includes tools for measuring the performance of wavefront shaping algorithms, and for identifying common problems such as incorrect SLM calibration, drift, measurement noise, and other experimental imperfections. diff --git a/docs/source/references.bib b/docs/source/references.bib index 7398070..998292c 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -428,6 +428,24 @@ @misc{openwfsdocumentation url = {https://openwfs.readthedocs.io/en/latest/}, } +@article{Anderson2024, + abstract = {We have developed a modular graphical user interface (GUI)-based program for use in genetic algorithm-based feedback-assisted wavefront shaping. The program uses a class-based structure to separate out the universal modules (e.g. GUI, multithreading, optimization algorithms) and hardware-specific modules (e.g. code for different SLMs and cameras). This modular design makes the program easily adaptable to a wide range of lab equipment, while providing easy access to a GUI, multithreading, and three optimization algorithms (phase-stepping, simple genetic, and microgenetic).}, + author = {Benjamin R. 
Anderson and Andrew O’Kins and Kostiantyn Makrasnov and Rebecca Udby and Patrick Price and Hergen Eilers}, + doi = {10.1088/2515-7647/AD6ED3}, + issn = {2515-7647}, + issue = {4}, + journal = {Journal of Physics: Photonics}, + keywords = {disordered media,genetic algorithms,multithreading,spatial light modulators,wavefront shaping}, + month = {8}, + pages = {045008}, + publisher = {IOP Publishing}, + title = {A modular GUI-based program for genetic algorithm-based feedback-assisted wavefront shaping}, + volume = {6}, + url = {https://iopscience.iop.org/article/10.1088/2515-7647/ad6ed3 https://iopscience.iop.org/article/10.1088/2515-7647/ad6ed3/meta}, + year = {2024}, +} + + @article{Bachelard2014, abstract = {A laser is not necessarily a sophisticated device: pumping an amplifying medium randomly filled with scatterers makes a perfectly viable â ̃ random laserâ ™. The absence of mirrors greatly simplifies laser design, but control over the emission wavelength and directionality is lost, seriously hindering prospects for this otherwise simple laser. Recently, we proposed an approach to tame random lasers, inspired by coherent light control in complex media. Here, we implement this method in an optofluidic random laser where modes are spatially extended and overlap, making individual mode selection impossible, a priori. We show experimentally that control over laser emission can be regained even in this extreme case. By actively shaping the optical pump within the random laser, single-mode operation at any selected wavelength is achieved with spectral selectivity down to 0.06 nm and more than 10 dB side-lobe rejection. This method paves the way towards versatile tunable and controlled random lasers as well as the taming of other laser sources. © 2014 Macmillan Publishers Limited. All rights reserved.}, author = {Nicolas Bachelard and Sylvain Gigan and Xavier Noblin and Patrick Sebbah}, From df3312f341bdcae3cd4faf93a4b355ef020ffca6 Mon Sep 17 00:00:00 2001 From: Daniel Cox Date: Tue, 1 Oct 2024 14:27:29 +0200 Subject: [PATCH 08/12] Add plot functionality --- .../algorithms/custom_iter_dual_reference.py | 220 ++++++++++++++++++ openwfs/plot_utilities.py | 156 +++++++++++++ tests/test_wfs.py | 28 +-- 3 files changed, 388 insertions(+), 16 deletions(-) create mode 100644 openwfs/algorithms/custom_iter_dual_reference.py diff --git a/openwfs/algorithms/custom_iter_dual_reference.py b/openwfs/algorithms/custom_iter_dual_reference.py new file mode 100644 index 0000000..73b083b --- /dev/null +++ b/openwfs/algorithms/custom_iter_dual_reference.py @@ -0,0 +1,220 @@ +from typing import Optional + +import numpy as np +from numpy import ndarray as nd + +from .utilities import analyze_phase_stepping, WFSResult +from ..core import Detector, PhaseSLM + + +def weighted_average(a, b, wa, wb): + """ + Compute the weighted average of two values. + + Args: + a: The first value. + b: The second value. + wa: The weight of the first value. + wb: The weight of the second value. + """ + return (a * wa + b * wb) / (wa + wb) + + +class IterativeDualReference: + """ + A generic iterative dual reference WFS algorithm, which can use a custom set of basis functions. + + This algorithm is adapted from [1], with the addition of the ability to use custom basis functions and specify the number of iterations. + + In this algorithm, the SLM pixels are divided into two groups: A and B, as indicated by the boolean group_mask argument. 
+ The algorithm first keeps the pixels in group B fixed, and displays a sequence of patterns on the pixels of group A. + It uses these measurements to construct an optimized wavefront that is displayed on the pixels of group A. + This process is then repeated for the pixels of group B, now using the *optimized* wavefront on group A as reference. + Optionally, the process can be repeated for a number of iterations, with each iteration using the current correction + pattern as a reference. This makes this algorithm suitable for non-linear feedback, such as multi-photon + excitation fluorescence [2]. + + This algorithm assumes a phase-only SLM. Hence, the input modes are defined by passing the corresponding phase + patterns (in radians) as an input argument. + + [1]: X. Tao, T. Lam, B. Zhu, et al., “Three-dimensional focusing through scattering media using conjugate adaptive + optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017). + + [2]: Gerwin Osnabrugge, Lyubov V. Amitonova, and Ivo M. Vellekoop. "Blind focusing through strongly scattering media + using wavefront shaping with nonlinear feedback", Optics Express, 27(8):11673–11688, 2019. + https://opg.optica.org/oe/abstract.cfm?uri=oe-27-8-1167 + """ + + def __init__(self, feedback: Detector, slm: PhaseSLM, phase_patterns: tuple[nd, nd], group_mask: nd, + phase_steps: int = 4, iterations: int = 4, analyzer: Optional[callable] = analyze_phase_stepping): + """ + Args: + feedback: The feedback source, usually a detector that provides measurement data. + slm: Spatial light modulator object. + phase_patterns: A tuple of two 3D arrays, containing the phase patterns for group A and group B, respectively. + The first two dimensions are the spatial dimensions, and should match the size of group_mask. + The 3rd dimension in the array is the index of the phase pattern. The number of phase patterns in A and B may be different. + group_mask: A 2D bool array that defines the pixels used by group A with False and elements used by + group B with True. + phase_steps: The number of phase steps for each mode (default is 4). Depending on the type of + non-linear feedback and the SNR, more might be required. + iterations: Number of times to measure a mode set, e.g. when iterations = 5, the measurements are + A, B, A, B, A. Should be at least 2. + analyzer: The function used to analyze the phase stepping data. Must return a WFSResult object.
Defaults to `analyze_phase_stepping` + """ + if (phase_patterns[0].shape[0:2] != group_mask.shape) or (phase_patterns[1].shape[0:2] != group_mask.shape): + raise ValueError("The phase patterns and group mask must all have the same shape.") + if iterations < 2: + raise ValueError("The number of iterations must be at least 2.") + if np.prod(feedback.data_shape) != 1: + raise ValueError("The feedback detector should return a single scalar value.") + + self.slm = slm + self.feedback = feedback + self.phase_steps = phase_steps + self.iterations = iterations + self.analyzer = analyzer + self.phase_patterns = (phase_patterns[0].astype(np.float32), phase_patterns[1].astype(np.float32)) + mask = group_mask.astype(bool) + self.masks = (~mask, mask) # masks[0] is True for group A, mask[1] is True for group B + + # Pre-compute the conjugate modes for reconstruction + self.modes = [np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in + range(2)] + + def execute(self, capture_intermediate_results: bool = False, progress_bar=None) -> WFSResult: + """ + Executes the blind focusing dual reference algorithm and compute the SLM transmission matrix. + capture_intermediate_results: When True, measures the feedback from the optimized wavefront after each iteration. + This can be useful to determine how many iterations are needed to converge to an optimal pattern. + This data is stored as the 'intermediate_results' field in the results + progress_bar: Optional progress bar object. Following the convention for tqdm progress bars, + this object should have a `total` attribute and an `update()` function. + + Returns: + WFSResult: An object containing the computed SLM transmission matrix and related data. The amplitude profile + of each mode is assumed to be 1. If a different amplitude profile is desired, this can be obtained by + multiplying that amplitude profile with this transmission matrix. + """ + + # Current estimate of the transmission matrix (start with all 0) + t_full = np.zeros(shape=self.modes[0].shape[0:2]) + t_other_side = t_full + + # Initialize storage lists + t_set_all = [None] * self.iterations + results_all = [None] * self.iterations # List to store all results + results_latest = [None, None] # The two latest results. Used for computing fidelity factors. 
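# Illustration: a minimal, self-contained sketch of how the phase-stepping measurements gathered
# by this algorithm encode the field of each test mode. It assumes the usual interference model
# I(phi) = |a + b*exp(i*phi)|**2; the coefficient conj(a)*b then appears as the first Fourier
# component over the phase-step axis (the `analyze_phase_stepping` analyzer is assumed to perform
# an equivalent analysis). All names below are local to this sketch, not package code.
import numpy as np

phase_steps = 4
a = 0.8 * np.exp(0.3j)  # reference field at the detector (example value)
b = 0.5 * np.exp(-1.2j)  # field contributed by the stepped mode (example value)
phi = 2 * np.pi * np.arange(phase_steps) / phase_steps
intensities = np.abs(a + b * np.exp(1j * phi)) ** 2
coefficient = np.fft.fft(intensities)[1] / phase_steps
print(np.allclose(coefficient, np.conj(a) * b))  # True: amplitude and phase of the mode are recovered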
+ intermediate_results = np.zeros(self.iterations) # List to store feedback from full patterns + + # Prepare progress bar + if progress_bar: + num_measurements = np.ceil(self.iterations / 2) * self.modes[0].shape[2] \ + + np.floor(self.iterations / 2) * self.modes[1].shape[2] + progress_bar.total = num_measurements + + # Switch the phase sets back and forth multiple times + for it in range(self.iterations): + side = it % 2 # pick set A or B for phase stepping + ref_phases = -np.angle(t_full) # use the best estimate so far to construct an optimized reference + side_mask = self.masks[side] + # Perform WFS experiment on one side, keeping the other side sized at the ref_phases + result = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases, + mod_mask=side_mask, progress_bar=progress_bar) + + # Compute transmission matrix for the current side and update + # estimated transmission matrix + t_this_side = self.compute_t_set(result, self.modes[side]) + t_full = t_this_side + t_other_side + t_other_side = t_this_side + + # Store results + t_set_all[it] = t_this_side # Store transmission matrix + results_all[it] = result # Store result + results_latest[side] = result # Store latest result for this set + + # Try full pattern + if capture_intermediate_results: + self.slm.set_phases(-np.angle(t_full)) + intermediate_results[it] = self.feedback.read() + + # Compute average fidelity factors + fidelity_noise = weighted_average(results_latest[0].fidelity_noise, + results_latest[1].fidelity_noise, results_latest[0].n, + results_latest[1].n) + fidelity_amplitude = weighted_average(results_latest[0].fidelity_amplitude, + results_latest[1].fidelity_amplitude, results_latest[0].n, + results_latest[1].n) + fidelity_calibration = weighted_average(results_latest[0].fidelity_calibration, + results_latest[1].fidelity_calibration, results_latest[0].n, + results_latest[1].n) + + result = WFSResult(t=t_full, + t_f=None, + n=self.modes[0].shape[2] + self.modes[1].shape[2], + axis=2, + fidelity_noise=fidelity_noise, + fidelity_amplitude=fidelity_amplitude, + fidelity_calibration=fidelity_calibration) + + # TODO: document the t_set_all and results_all attributes + result.t_set_all = t_set_all + result.results_all = results_all + result.intermediate_results = intermediate_results + return result + + def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, + progress_bar=None) -> WFSResult: + """ + Conducts experiments on one part of the SLM. + + Args: + mod_phases: 3D array containing the phase patterns of each mode. Axis 0 and 1 are used as spatial axis. + Axis 2 is used for the 'phase pattern index' or 'mode index'. + ref_phases: 2D array containing the reference phase pattern. + mod_mask: 2D array containing a boolean mask, where True indicates the modulated part of the SLM. + progress_bar: Optional progress bar object. Following the convention for tqdm progress bars, + this object should have a `total` attribute and an `update()` function. + + Returns: + WFSResult: An object containing the computed SLM transmission matrix and related data. 
+ """ + num_modes = mod_phases.shape[2] + measurements = np.zeros((num_modes, self.phase_steps)) + + for m in range(num_modes): + phases = ref_phases.copy() + modulated = mod_phases[:, :, m] + for p in range(self.phase_steps): + phi = p * 2 * np.pi / self.phase_steps + # set the modulated pixel values to the values corresponding to mode m and phase offset phi + phases[mod_mask] = modulated[mod_mask] + phi + self.slm.set_phases(phases) + self.feedback.trigger(out=measurements[m, p, ...]) + + if progress_bar is not None: + progress_bar.update() + + self.feedback.wait() + return self.analyzer(measurements, axis=1) + + @staticmethod + def compute_t_set(wfs_result: WFSResult, mode_set: nd) -> nd: + """ + Compute the transmission matrix in SLM space from transmission matrix in input mode space. + + Note 1: This function computes the transmission matrix for one mode set, and thus returns one part of the full + transmission matrix. The elements that are not part of the mode set will be 0. The full transmission matrix can + be obtained by simply adding the parts, i.e. t_full = t_set0 + t_set1. + + Note 2: As this is a blind focusing WFS algorithm, there may be only one target or 'output mode'. + + Args: + wfs_result (WFSResult): The result of the WFS algorithm. This contains the transmission matrix in the space + of input modes. + mode_set: 3D array with set of modes. + """ + t = wfs_result.t.squeeze().reshape((1, 1, mode_set.shape[2])) + norm_factor = np.prod(mode_set.shape[0:2]) + t_set = (t * mode_set).sum(axis=2) / norm_factor + return t_set diff --git a/openwfs/plot_utilities.py b/openwfs/plot_utilities.py index 1a9a6bd..e11d1e8 100644 --- a/openwfs/plot_utilities.py +++ b/openwfs/plot_utilities.py @@ -1,5 +1,11 @@ +from typing import Tuple + +import numpy as np +from numpy import ndarray as nd from astropy import units as u from matplotlib import pyplot as plt +from matplotlib.colors import hsv_to_rgb +from matplotlib.axes import Axes from .core import Detector from .utilities import get_extent @@ -48,3 +54,153 @@ def scale_prefix(value: u.Quantity) -> u.Quantity: return value.to(u.s) else: return value + + +def slope_step(a: nd, width: nd | float) -> nd: + """ + A sloped step function from 0 to 1. + + Args: + a: Input array + width: width of the sloped step. + + Returns: + An array the size of a, with the result of the sloped step function. + """ + return (a >= width) + a/width * (0 < a) * (a < width) + + +def linear_blend(a: nd, b: nd, blend: nd | float) -> nd: + """ + Return a linear, element-wise blend between two arrays a and b. + + Args: + a: Input array a. + b: Input array b. + blend: Blend factor. Value of 1.0 -> return a. Value of 0.0 -> return b. + + Returns: + A linear combination of a and b, corresponding to the blend factor. a*blend + b*(1-blend) + """ + return a*blend + b*(1-blend) + + +def complex_to_rgb(array: nd, scale: float | nd | None = None, axis: int = 2) -> nd: + """ + Generate RGB color values to represent values of a complex array. + + The complex values are mapped to HSV colorspace and then converted to RGB. Hue represents phase and Value represents + amplitude. Saturation is set to 1. + + Args: + array: Array to create RGB values for. + scale: Scaling factor for the array values. When None, scale = 1/max(abs(array)) is used. + axis: Array axis to use for the RGB dimension. + + Returns: + An RGB array representing the complex input array. 
+ """ + if scale is None: + scale = 1 / np.max(abs(array)) + h = np.expand_dims(np.angle(array) / (2 * np.pi) + 0.5, axis=axis) + s = np.ones_like(h) + v = np.expand_dims(np.abs(array) * scale, axis=axis).clip(min=0, max=1) + hsv = np.concatenate((h, s, v), axis=axis) + rgb = hsv_to_rgb(hsv) + return rgb + + +def plot_field(array, scale: float | nd | None = None, imshow_kwargs: dict | None = None): + """ + Plot a complex array as an RGB image. + + The phase is represented by the hue, and the magnitude by the value, i.e. black = zero, brightness shows amplitude, + and the colors represent the phase. + + Args: + array(ndarray): complex array to be plotted. + scale(float): scaling factor for the magnitude. The final value is clipped to the range [0, 1]. + imshow_kwargs: Keyword arguments for matplotlib's imshow. + """ + if imshow_kwargs is None: + imshow_kwargs = {} + rgb = complex_to_rgb(array, scale) + plt.imshow(rgb, **imshow_kwargs) + + +def plot_scatter_field(x, y, array, scale, scatter_kwargs=None): + """ + Plot complex scattered data as RGB values. + """ + if scatter_kwargs is None: + scatter_kwargs = {'s': 80} + rgb = complex_to_rgb(array, scale, axis=1) + plt.scatter(x, y, c=rgb, **scatter_kwargs) + + +def complex_colorbar(scale, width_inverse: int = 15): + """ + Create an rgb colorbar for complex numbers and return its Axes handle. + """ + amp = np.linspace(0, 1.01, 10).reshape((1, -1)) + phase = np.linspace(0, 249 / 250 * 2 * np.pi, 250).reshape(-1, 1) - np.pi + z = amp * np.exp(1j * phase) + rgb = complex_to_rgb(z, 1) + ax = plt.subplot(1, width_inverse, width_inverse) + plt.imshow(rgb, aspect='auto', extent=(0, scale, -np.pi, np.pi)) + + # Ticks and labels + ax.set_yticks((-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi), ('$-\\pi$', '$-\\pi/2$', '0', '$\\pi/2$', '$\\pi$')) + ax.set_xlabel('amp.') + ax.set_ylabel('phase (rad)') + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + return ax + + +def complex_colorwheel(ax: Axes = None, shape: Tuple[int, int] = (100, 100), imshow_kwargs: dict = {}, + arrow_props: dict = {}, text_kwargs: dict = {}, amplitude_str: str = 'A', + phase_str: str = '$\\phi$'): + """ + Create an rgb image for a colorwheel representing the complex unit circle. + + Args: + ax: Matplotlib Axes. + shape: Number of pixels in each dimension. + imshow_kwargs: Keyword arguments for matplotlib's imshow. + arrow_props: Keyword arguments for the arrows. + text_kwargs: Keyword arguments for the text labels. + amplitude_str: Text label for the amplitude arrow. + phase_str: Text label for the phase arrow. + + Returns: + rgb_wheel: rgb image of the colorwheel. 
+ """ + if ax is None: + ax = plt.gca() + + x = np.linspace(-1, 1, shape[1]).reshape(1, -1) + y = np.linspace(-1, 1, shape[0]).reshape(-1, 1) + z = x + 1j*y + rgb = complex_to_rgb(z, scale=1) + step_width = 1.5 / shape[1] + blend = np.expand_dims(slope_step(1 - np.abs(z) - step_width, width=step_width), axis=2) + rgba_wheel = np.concatenate((rgb, blend), axis=2) + ax.imshow(rgba_wheel, extent=(-1, 1, -1, 1), **imshow_kwargs) + + # Add arrows with annotations + ax.annotate('', xy=(-0.98/np.sqrt(2),)*2, xytext=(0, 0), arrowprops={'color': 'white', 'width': 1.8, + 'headwidth': 5.0, 'headlength': 6.0, **arrow_props}) + ax.text(**{'x': -0.4, 'y': -0.8, 's': amplitude_str, 'color': 'white', 'fontsize': 15, **text_kwargs}) + ax.annotate('', xy=(0, 0.9), xytext=(0.9, 0), + arrowprops={'connectionstyle': 'arc3,rad=0.4', 'color': 'white', 'width': 1.8, 'headwidth': 5.0, + 'headlength': 6.0, **arrow_props}) + ax.text(**{'x': 0.1, 'y': 0.5, 's': phase_str, 'color': 'white', 'fontsize': 15, **text_kwargs}) + + # Hide axes spines and ticks + ax.set_xticks([]) + ax.set_yticks([]) + ax.spines['left'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.spines['bottom'].set_visible(False) diff --git a/tests/test_wfs.py b/tests/test_wfs.py index f134287..f2ddce5 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -9,8 +9,9 @@ from ..openwfs.algorithms.troubleshoot import field_correlation from ..openwfs.algorithms.utilities import WFSController from ..openwfs.processors import SingleRoi -from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope from ..openwfs.simulation.mockdevices import GaussianNoise +from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope +from ..openwfs.plot_utilities import plot_field @pytest.mark.parametrize("shape", [(4, 7), (10, 7), (20, 31)]) @@ -436,15 +437,13 @@ def test_custom_blind_dual_reference_non_ortho(): """ Test custom blind dual reference with a non-orthogonal basis. 
""" - do_debug = False + do_debug = True # Create set of modes that are barely linearly independent N1 = 6 N2 = 3 M = N1 * N2 - mode_set_half = (1 / M) * ( - 1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M)) - ) + mode_set_half = (1 / M) * (1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M))) + (1/M) mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1) phases_set = np.angle(mode_set) mask = np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1) @@ -456,8 +455,8 @@ def test_custom_blind_dual_reference_non_ortho(): plt.figure(figsize=(12, 7)) for m in range(M): plt.subplot(N2, N1, m + 1) - plt.imshow(phases_set[:, :, m], vmin=-np.pi, vmax=np.pi) - plt.title(f"m={m}") + plot_field(mode_set[:, :, m]) + plt.title(f'm={m}') plt.xticks([]) plt.yticks([]) plt.pause(0.01) @@ -489,16 +488,13 @@ def test_custom_blind_dual_reference_non_ortho(): if do_debug: plt.figure() - plt.imshow( - np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap="hsv" - ) - plt.title("Aberrations") - plt.colorbar() + plt.subplot(1, 2, 1) + plot_field(np.exp(1j * aberrations)) + plt.title('Aberrations') - plt.figure() - plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") - plt.title("t") - plt.colorbar() + plt.subplot(1, 2, 2) + plot_field(result.t) + plt.title('t') plt.show() assert np.abs(field_correlation(np.exp(1j * aberrations), result.t)) > 0.999 From ae639300ba2d7d09725ec3b8ff8fed51ebdef26f Mon Sep 17 00:00:00 2001 From: Daniel Cox Date: Tue, 1 Oct 2024 17:33:15 +0200 Subject: [PATCH 09/12] Add generic cobasis computation. Works also for non-ortho basis --- openwfs/algorithms/dual_reference.py | 71 +++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index 1f6fd92..7478c92 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -38,6 +38,7 @@ def __init__( feedback: Detector, slm: PhaseSLM, phase_patterns: Optional[tuple[nd, nd]], + amplitude: Optional[nd | str], group_mask: nd, phase_steps: int = 4, iterations: int = 2, @@ -52,6 +53,10 @@ def __init__( The first two dimensions are the spatial dimensions, and should match the size of group_mask. The 3rd dimension in the array is index of the phase pattern. The number of phase patterns in A and B may be different. When None, the phase_patterns attribute must be set before executing the algorithm. + amplitude: A 2D array, with shape equal to the shape of group_mask. + When None, the amplitude attribute must be set before executing the algorithm. + When 'ones', a 2D array where each element is 1.0 will be used. This corresponds to a uniform + illumination of the SLM. group_mask: A 2D bool array of that defines the pixels used by group A with False and elements used by group B with True. phase_steps: The number of phase steps for each mode (default is 4). Depending on the type of @@ -100,8 +105,32 @@ def __init__( self.masks = ( ~mask, mask, - ) # mask[0] is True for group A, mask[1] is True for group B + ) # self.masks[0] is True for group A, self.masks[1] is True for group B self.phase_patterns = phase_patterns + self._amplitude = None + self.amplitude = amplitude # Note: when 'ones' is passed, the shape of self.masks[0] is used. 
+ + @property + def amplitude(self) -> Optional[nd]: + return self._amplitude + + @amplitude.setter + def amplitude(self, value): + if value is None: + self._amplitude = None + return + + if value == 'ones': + self._amplitude = np.ones(shape=self._shape, dtype=np.float32) + return + + if value.shape != self._shape: + raise ValueError( + "The amplitude and group mask must all have the same shape." + ) + + self._amplitude = value.astype(np.float32) + @property def phase_patterns(self) -> tuple[nd, nd]: @@ -143,6 +172,40 @@ def phase_patterns(self, value): value[1].astype(np.float32), ) + self._compute_cobasis() + + @property + def cobasis(self) -> tuple[nd, nd]: + return self._cobasis + + def _compute_cobasis(self): + """ + Computes the cobasis from the phase patterns. + + As a basis matrix is full rank, this is equivalent to the Moore-Penrose pseudo-inverse. + B⁺ = B* (B B*)^(-1) + Where B is the basis matrix (a row corresponds to a basis vector), * denotes the conjugate transpose, ^(-1) + denotes the matrix inverse, and ⁺ denotes the Moore-Penrose pseudo-inverse. + """ + if self.phase_patterns is None: + raise('The phase_patterns must be set before computing the cobasis.') + + self._cobasis = tuple( + np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.amplitude * self.masks[side], axis=2) + for side in range(2) + ) + + cobasis = [None, None] + for side in range(2): + p = np.prod(self._shape) # Number of SLM pixels + m = self.phase_patterns[side].shape[2] # Number of modes + B = np.exp(1j * self.phase_patterns[side]) \ + * np.expand_dims(self.amplitude * self.masks[side], axis=2).reshape((p, m)) # Basis matrix + B_pinv = np.linalg.inv(B.conj() @ B) @ B.conj() # Moore-Penrose pseudo-inverse + cobasis[side] = B_pinv.reshape(self.phase_patterns[side].shape) + + self._cobasis = cobasis + def execute( self, capture_intermediate_results: bool = False, progress_bar=None ) -> WFSResult: @@ -161,12 +224,6 @@ def execute( """ # Current estimate of the transmission matrix (start with all 0) - cobasis = [ - np.exp(-1j * self.phase_patterns[side]) - * np.expand_dims(self.masks[side], axis=2) - for side in range(2) - ] - ref_phases = np.zeros(self._shape) # Initialize storage lists From 119e879ab68db8fea0c584f14a52451fdfd7968d Mon Sep 17 00:00:00 2001 From: Daniel Cox Date: Tue, 1 Oct 2024 17:36:14 +0200 Subject: [PATCH 10/12] Update doc cobasis --- openwfs/algorithms/dual_reference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index 7478c92..dbe462d 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -183,8 +183,8 @@ def _compute_cobasis(self): Computes the cobasis from the phase patterns. As a basis matrix is full rank, this is equivalent to the Moore-Penrose pseudo-inverse. - B⁺ = B* (B B*)^(-1) - Where B is the basis matrix (a row corresponds to a basis vector), * denotes the conjugate transpose, ^(-1) + B⁺ = (B^* B)^(-1) B^* + Where B is the basis matrix (a column corresponds to a basis vector), ^* denotes the conjugate transpose, ^(-1) denotes the matrix inverse, and ⁺ denotes the Moore-Penrose pseudo-inverse. 
""" if self.phase_patterns is None: From 7f5d768e87e60f8fcfd2f6564101daaeb07734a5 Mon Sep 17 00:00:00 2001 From: Daniel Cox Date: Wed, 2 Oct 2024 11:44:05 +0200 Subject: [PATCH 11/12] Fix cobasis computation, change 'ones' to 'uniform' amplitude --- openwfs/algorithms/dual_reference.py | 55 ++++++++++++++++------------ tests/test_wfs.py | 21 ++++++++--- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index dbe462d..409567b 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -38,7 +38,7 @@ def __init__( feedback: Detector, slm: PhaseSLM, phase_patterns: Optional[tuple[nd, nd]], - amplitude: Optional[nd | str], + amplitude: Optional[tuple[nd, nd] | str], group_mask: nd, phase_steps: int = 4, iterations: int = 2, @@ -53,10 +53,11 @@ def __init__( The first two dimensions are the spatial dimensions, and should match the size of group_mask. The 3rd dimension in the array is index of the phase pattern. The number of phase patterns in A and B may be different. When None, the phase_patterns attribute must be set before executing the algorithm. - amplitude: A 2D array, with shape equal to the shape of group_mask. - When None, the amplitude attribute must be set before executing the algorithm. - When 'ones', a 2D array where each element is 1.0 will be used. This corresponds to a uniform - illumination of the SLM. + amplitude: Tuple of 2D arrays, one array for each group. The arrays have shape equal to the shape of + group_mask. When None, the amplitude attribute must be set before executing the algorithm. When + 'uniform', a 2D array of normalized uniform values is used, such that ⟨A,A⟩=1, where ⟨.,.⟩ denotes the + inner product and A is the amplitude profile per group. This corresponds to a uniform illumination of + the SLM. Note: if the groups have different sizes, their normalization factors will be different. group_mask: A 2D bool array of that defines the pixels used by group A with False and elements used by group B with True. phase_steps: The number of phase steps for each mode (default is 4). Depending on the type of @@ -77,7 +78,6 @@ def __init__( [1]: X. Tao, T. Lam, B. Zhu, et al., “Three-dimensional focusing through scattering media using conjugate adaptive optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017). - """ if optimized_reference is None: # 'auto' mode optimized_reference = np.prod(feedback.data_shape) == 1 @@ -100,15 +100,16 @@ def __init__( self.iterations = iterations self._analyzer = analyzer self._phase_patterns = None + self._amplitude = None + self._gram = None self._shape = group_mask.shape mask = group_mask.astype(bool) self.masks = ( ~mask, mask, ) # self.masks[0] is True for group A, self.masks[1] is True for group B + self.amplitude = amplitude # Note: when 'uniform' is passed, the shape of self.masks[0] is used. self.phase_patterns = phase_patterns - self._amplitude = None - self.amplitude = amplitude # Note: when 'ones' is passed, the shape of self.masks[0] is used. 
@property def amplitude(self) -> Optional[nd]: @@ -120,8 +121,9 @@ def amplitude(self, value): self._amplitude = None return - if value == 'ones': - self._amplitude = np.ones(shape=self._shape, dtype=np.float32) + if value == 'uniform': + self._amplitude = tuple( + (np.ones(shape=self._shape) / np.sqrt(self.masks[side].sum())).astype(np.float32) for side in range(2)) return if value.shape != self._shape: @@ -176,8 +178,18 @@ def phase_patterns(self, value): @property def cobasis(self) -> tuple[nd, nd]: + """ + The cobasis corresponding to the given basis. + """ return self._cobasis + @property + def gram(self) -> np.matrix: + """ + The Gram matrix corresponding to the given basis (i.e. phase pattern and amplitude profile). + """ + return self._gram + def _compute_cobasis(self): """ Computes the cobasis from the phase patterns. @@ -190,19 +202,16 @@ def _compute_cobasis(self): if self.phase_patterns is None: raise('The phase_patterns must be set before computing the cobasis.') - self._cobasis = tuple( - np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.amplitude * self.masks[side], axis=2) - for side in range(2) - ) - cobasis = [None, None] for side in range(2): p = np.prod(self._shape) # Number of SLM pixels m = self.phase_patterns[side].shape[2] # Number of modes - B = np.exp(1j * self.phase_patterns[side]) \ - * np.expand_dims(self.amplitude * self.masks[side], axis=2).reshape((p, m)) # Basis matrix - B_pinv = np.linalg.inv(B.conj() @ B) @ B.conj() # Moore-Penrose pseudo-inverse - cobasis[side] = B_pinv.reshape(self.phase_patterns[side].shape) + phase_factor = np.exp(1j * self.phase_patterns[side]) + amplitude_factor = np.expand_dims(self.amplitude[side] * self.masks[side], axis=2) + B = np.asmatrix((phase_factor * amplitude_factor).reshape((p, m))) # Basis matrix + self._gram = B.H @ B + B_pinv = np.linalg.inv(self.gram) @ B.H # Moore-Penrose pseudo-inverse + cobasis[side] = np.asarray(B_pinv).reshape(self.phase_patterns[side].shape) self._cobasis = cobasis @@ -257,9 +266,7 @@ def execute( if self.optimized_reference: # use the best estimate so far to construct an optimized reference - t_this_side = self.compute_t_set( - results_all[it].t, cobasis[side] - ).squeeze() + t_this_side = self.compute_t_set(results_all[it].t, self.cobasis[side]).squeeze() ref_phases[self.masks[side]] = -np.angle(t_this_side[self.masks[side]]) # Try full pattern @@ -281,8 +288,8 @@ def execute( (1, *self.feedback.data_shape) ) - t_full = self.compute_t_set(results_all[0].t, cobasis[0]) + self.compute_t_set( - factor * results_all[1].t, cobasis[1] + t_full = self.compute_t_set(results_all[0].t, self.cobasis[0]) + self.compute_t_set( + factor * results_all[1].t, self.cobasis[1] ) # Compute average fidelity factors diff --git a/tests/test_wfs.py b/tests/test_wfs.py index f2ddce5..d2f8627 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -370,19 +370,21 @@ def test_multidimensional_feedback_fourier(): Expected at least 3.0, got {enhancement}""" -@pytest.mark.parametrize("type", ("plane_wave", "hadamard")) +@pytest.mark.parametrize("basis_str", ("plane_wave", "hadamard")) @pytest.mark.parametrize("shape", ((8, 8), (16, 4))) -def test_custom_blind_dual_reference_ortho_split(type: str, shape): +def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): """Test custom blind dual reference with an orthonormal phase-only basis. 
Two types of bases are tested: plane waves and Hadamard""" - do_debug = False + do_debug = True N = shape[0] * (shape[1] // 2) modes_shape = (shape[0], shape[1] // 2, N) - if type == "plane_wave": + if basis_str == "plane_wave": # Create a full plane wave basis for one half of the SLM. modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) - else: # type == 'hadamard': + elif basis_str == 'hadamard': modes = hadamard(N).reshape(modes_shape) + else: + raise f'Unknown type of basis "{basis_str}".' mask = np.concatenate( (np.zeros(modes_shape[0:2], dtype=bool), np.ones(modes_shape[0:2], dtype=bool)), @@ -397,7 +399,7 @@ def test_custom_blind_dual_reference_ortho_split(type: str, shape): plt.figure(figsize=(12, 7)) for m in range(N): - plt.subplot(*modes_shape[0:1], m + 1) + plt.subplot(*modes_shape[0:2], m + 1) plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi) plt.title(f"m={m}") plt.xticks([]) @@ -411,10 +413,16 @@ def test_custom_blind_dual_reference_ortho_split(type: str, shape): feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), + amplitude='uniform', group_mask=mask, iterations=4, ) + assert np.allclose(alg.gram, np.eye(N), atol=1e-6) + + for m in range(N): + alg.cobasis + result = alg.execute() if do_debug: @@ -479,6 +487,7 @@ def test_custom_blind_dual_reference_non_ortho(): feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), + amplitude='ones', group_mask=mask, phase_steps=4, iterations=4, From 7727b33db0a48993d5aebb347233d82a25433729 Mon Sep 17 00:00:00 2001 From: Daniel Cox Date: Wed, 2 Oct 2024 14:34:49 +0200 Subject: [PATCH 12/12] Dual Reference working for non-orthonormal basis --- openwfs/algorithms/dual_reference.py | 2 +- tests/test_wfs.py | 61 +++++++++++++++++++--------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index 409567b..9110365 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -211,7 +211,7 @@ def _compute_cobasis(self): B = np.asmatrix((phase_factor * amplitude_factor).reshape((p, m))) # Basis matrix self._gram = B.H @ B B_pinv = np.linalg.inv(self.gram) @ B.H # Moore-Penrose pseudo-inverse - cobasis[side] = np.asarray(B_pinv).reshape(self.phase_patterns[side].shape) + cobasis[side] = np.asarray(B_pinv.T).reshape(self.phase_patterns[side].shape) self._cobasis = cobasis diff --git a/tests/test_wfs.py b/tests/test_wfs.py index d2f8627..3949b35 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -375,14 +375,14 @@ def test_multidimensional_feedback_fourier(): def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): """Test custom blind dual reference with an orthonormal phase-only basis. Two types of bases are tested: plane waves and Hadamard""" - do_debug = True + do_debug = False N = shape[0] * (shape[1] // 2) modes_shape = (shape[0], shape[1] // 2, N) if basis_str == "plane_wave": # Create a full plane wave basis for one half of the SLM. - modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) + modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) / np.sqrt(N) elif basis_str == 'hadamard': - modes = hadamard(N).reshape(modes_shape) + modes = hadamard(N).reshape(modes_shape) / np.sqrt(N) else: raise f'Unknown type of basis "{basis_str}".' 
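The hunk above adds a 1/√N factor to both test bases. As a quick standalone sanity check (not part of the patch; the (8, 8) shape is just one of the parametrized cases and all variable names are local to the sketch), flattening the normalized modes into a matrix with one column per mode shows that the Gram matrix B^* B becomes the identity, which is what the alg.gram assertion in this test relies on:

    import numpy as np
    from scipy.linalg import hadamard

    shape = (8, 8)                   # one of the parametrized test shapes
    N = shape[0] * (shape[1] // 2)   # number of modes per group (32 here)
    modes_shape = (shape[0], shape[1] // 2, N)

    # Same construction as in the test, including the new 1/sqrt(N) factor
    plane_waves = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) / np.sqrt(N)
    hadamard_modes = hadamard(N).reshape(modes_shape) / np.sqrt(N)

    for modes in (plane_waves, hadamard_modes):
        B = modes.reshape(-1, N)               # flatten the spatial axes: one column per mode
        gram = B.conj().T @ B                  # Gram matrix of the basis
        assert np.allclose(gram, np.eye(N))    # orthonormal basis: B^* B = I

Without the 1/√N normalization the Gram matrix would be N·I rather than the identity, and the cobasis would carry a compensating factor of 1/N.
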
@@ -400,11 +400,12 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): plt.figure(figsize=(12, 7)) for m in range(N): plt.subplot(*modes_shape[0:2], m + 1) - plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi) + plot_field(mode_set[:, :, m]) plt.title(f"m={m}") plt.xticks([]) plt.yticks([]) - plt.pause(0.1) + plt.suptitle('Basis') + plt.pause(0.01) # Create aberrations sim = SimulatedWFS(t=random_transmission_matrix(shape)) @@ -418,14 +419,17 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): iterations=4, ) - assert np.allclose(alg.gram, np.eye(N), atol=1e-6) - - for m in range(N): - alg.cobasis - result = alg.execute() if do_debug: + plt.figure() + for m in range(N): + plt.subplot(*modes_shape[0:2], m + 1) + plot_field(alg.cobasis[0][:, :, m]) + plt.title(f'{m}') + plt.suptitle('Cobasis') + plt.pause(0.01) + plt.figure() plt.imshow(np.angle(sim.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") plt.title("Aberrations") @@ -436,22 +440,25 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): plt.colorbar() plt.show() - assert ( - np.abs(field_correlation(sim.t, result.t)) > 0.99 - ) # todo: find out why this is not higher + # Checks for orthonormal bases + assert np.allclose(alg.gram, np.eye(N), atol=1e-6) # Gram matrix must be I + assert np.allclose(alg.cobasis[0], mode_set.conj(), atol=1e-6) # Cobasis vectors are just the complex conjugates + + # todo: find out why this is not higher + assert np.abs(field_correlation(sim.t, result.t)) > 0.95 def test_custom_blind_dual_reference_non_ortho(): """ Test custom blind dual reference with a non-orthogonal basis. """ - do_debug = True + do_debug = False # Create set of modes that are barely linearly independent N1 = 6 N2 = 3 M = N1 * N2 - mode_set_half = (1 / M) * (1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M))) + (1/M) + mode_set_half = np.exp(2j*np.pi/3 * np.eye(M).reshape((N1, N2, M))) / np.sqrt(M) mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1) phases_set = np.angle(mode_set) mask = np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1) @@ -487,7 +494,7 @@ def test_custom_blind_dual_reference_non_ortho(): feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), - amplitude='ones', + amplitude='uniform', group_mask=mask, phase_steps=4, iterations=4, @@ -495,15 +502,31 @@ def test_custom_blind_dual_reference_non_ortho(): result = alg.execute() + aberration_field = np.exp(1j * aberrations) + t_field = np.exp(1j * np.angle(result.t)) + if do_debug: + plt.figure() + for m in range(M): + plt.subplot(N2, N1, m + 1) + plot_field(alg.cobasis[0][:, :, m], scale=2) + plt.title(f'{m}') + plt.suptitle('Cobasis') + plt.pause(0.01) + + plt.figure() + plt.imshow(abs(alg.gram), vmin=0, vmax=1) + plt.title('Gram matrix abs values') + plt.colorbar() + plt.figure() plt.subplot(1, 2, 1) - plot_field(np.exp(1j * aberrations)) + plot_field(aberration_field) plt.title('Aberrations') plt.subplot(1, 2, 2) - plot_field(result.t) + plot_field(t_field) plt.title('t') plt.show() - assert np.abs(field_correlation(np.exp(1j * aberrations), result.t)) > 0.999 + assert np.abs(field_correlation(aberration_field, t_field)) > 0.999
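
Taken together, patches 09 through 12 settle on the following recipe for the cobasis of a possibly non-orthogonal group basis. The sketch below is a standalone illustration of that math, not a call into the library: it uses plain ndarrays instead of the np.asmatrix/.H style of _compute_cobasis, and the toy shape, group mask, and random seed are made up; the names only mirror the quantities in the patch. It builds a masked basis with the 'uniform' amplitude normalization, forms the Gram matrix, and checks that B⁺ = (B^* B)^(-1) B^* inverts the basis and agrees with NumPy's general pseudo-inverse:

    import numpy as np

    # Toy geometry: a 4x4 SLM, group A is the left half, one mode per group-A pixel.
    shape = (4, 4)
    mask_b = np.concatenate((np.zeros((4, 2), dtype=bool), np.ones((4, 2), dtype=bool)), axis=1)
    mask_a = ~mask_b
    amplitude_a = np.ones(shape) / np.sqrt(mask_a.sum())   # 'uniform': <A, A> = 1 within group A

    rng = np.random.default_rng(1)
    M = int(mask_a.sum())                                   # number of modes for group A
    phase_patterns_a = 2 * np.pi * rng.random((*shape, M))  # arbitrary, generally non-orthogonal

    p = np.prod(shape)                                      # number of SLM pixels
    basis = (np.exp(1j * phase_patterns_a) * (amplitude_a * mask_a)[..., np.newaxis]).reshape(p, M)

    gram = basis.conj().T @ basis                           # B^* B, not the identity in general
    basis_pinv = np.linalg.inv(gram) @ basis.conj().T       # B⁺ = (B^* B)^(-1) B^*
    assert np.allclose(basis_pinv @ basis, np.eye(M))       # the cobasis recovers mode coefficients
    assert np.allclose(basis_pinv, np.linalg.pinv(basis))   # matches NumPy's SVD-based pinv

For an orthonormal basis the Gram matrix reduces to the identity and B⁺ is simply B^*, which is why the orthonormal test above can assert that alg.gram is the identity and that the cobasis equals the complex conjugate of the mode set.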