Skip to content

Commit

Permalink
Merge remote-tracking branch 'remotes/origin/dd/dev' into hackathon_2024
Browse files Browse the repository at this point in the history
  • Loading branch information
diptodip committed Jun 10, 2024
2 parents cc0bf40 + cc3e217 commit 736d9e7
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 39 deletions.
17 changes: 11 additions & 6 deletions src/chromatix/elements/phase_masks.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import Callable, Optional, Tuple, Union

import jax.numpy as jnp
from chex import Array, PRNGKey
from flax import linen as nn
from jax.scipy.ndimage import map_coordinates

from chromatix.field import Field
from chromatix.functional import wrap_phase, phase_change
from chromatix.utils import seidel_aberrations, zernike_aberrations
from chromatix.ops import quantize
from chromatix.elements.utils import register

from ..field import Field
from ..functional import phase_change, wrap_phase
from ..utils import seidel_aberrations, zernike_aberrations

__all__ = [
"PhaseMask",
"SpatialLightModulator",
Expand Down Expand Up @@ -110,6 +108,10 @@ class SpatialLightModulator(nn.Module):
spacing: The pitch of the SLM pixels.
phase_range: The phase range that the SLM can simulate, provided as
(min, max).
num_bits: The number of bits of precision the phase pixels should be
quantized to. Defaults to None, in which case no quantization is
applied. Otherwise, the phase will be quantized to have
``2.0 ** num_bits`` values within ``phase_range``.
interpolation_order: The order of interpolation for the SLM pixels to
the shape of the incoming ``Field``. Can be 0 or 1. Defaults to 0.
f: Focal length of the system's objective. Defaults to None.
Expand All @@ -121,6 +123,7 @@ class SpatialLightModulator(nn.Module):
shape: Tuple[int, int]
spacing: float
phase_range: Tuple[float, float]
num_bits: Optional[Union[int, float]] = None
interpolation_order: int = 0
f: Optional[float] = None
n: Optional[float] = None
Expand All @@ -146,6 +149,8 @@ def __call__(self, field: Field) -> Field:
phase.shape == self.shape
), "Provided phase shape should match provided SLM shape"
phase = wrap_phase(phase, self.phase_range)
if self.num_bits is not None:
phase = quantize(phase, 2.0**self.num_bits, range=self.phase_range)
field_pixel_grid = jnp.meshgrid(
jnp.linspace(0, self.shape[0] - 1, num=field.spatial_shape[0]) + 0.5,
jnp.linspace(0, self.shape[1] - 1, num=field.spatial_shape[1]) + 0.5,
Expand Down
8 changes: 4 additions & 4 deletions src/chromatix/elements/sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __call__(
self,
sensor_input: Union[Field, Array],
input_spacing: Optional[Union[float, Array]] = None,
resample: bool = True,
) -> Array:
"""
Resample the given ``sensor_input`` to the pixels of the sensor and
Expand All @@ -61,16 +62,15 @@ def __call__(
sensor_input: The incoming ``Field`` or intensity ``Array``.
input_spacing: The spacing of the input, only required if resampling
is required and the input is an ``Array``.
resample: Whether to perform resampling or not. Only matters if
``resampling_method`` is ``None``. Defaults to ``True``.
"""
if isinstance(sensor_input, Field):
# WARNING(dd): @copypaste(Microscope) Assumes that field has same
# spacing at all wavelengths when calculating intensity!
input_spacing = sensor_input.dx[..., 0, 0].squeeze()
input_spacing = jnp.atleast_1d(input_spacing)
# Only want to resample if the spacing does not match
if self.resampling_method is not None and jnp.any(
input_spacing != self.spacing
):
if resample and self.resampling_method is not None:
resample_fn = self.resample_fn
else:
resample_fn = None
Expand Down
16 changes: 11 additions & 5 deletions src/chromatix/elements/sources.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Callable, Optional, Tuple, Union

import numpy as np
import flax.linen as nn
from chex import Array, PRNGKey

from chromatix.elements.utils import register

from ..field import Field
from ..functional.sources import (
from chromatix.field import Field
from chromatix.functional.sources import (
generic_field,
objective_point_source,
plane_wave,
Expand Down Expand Up @@ -41,6 +39,7 @@ class PointSource(nn.Module):
pupil: If provided, will be called on the field to apply a pupil.
scalar: Whether the result should be ``ScalarField`` (if True) or
``VectorField`` (if False). Defaults to True.
epsilon: Value added to denominators for numerical stability.
"""

shape: Tuple[int, int]
Expand All @@ -53,6 +52,7 @@ class PointSource(nn.Module):
amplitude: Union[float, Array, Callable[[PRNGKey], Array]] = 1.0
pupil: Optional[Callable[[Field], Field]] = None
scalar: bool = True
epsilon: float = np.finfo(np.float32).eps,

@nn.compact
def __call__(self) -> Field:
Expand All @@ -71,6 +71,7 @@ def __call__(self) -> Field:
amplitude,
self.pupil,
self.scalar,
self.epsilon
)


Expand All @@ -97,6 +98,8 @@ class ObjectivePointSource(nn.Module):
amplitude: The amplitude of the electric field. For ``ScalarField`` this
doesnt do anything, but it is required for ``VectorField`` to set
the polarization.
offset: The offset (y and x) in spatial coordinates of the point source.
Defaults to (0, 0) for no offset (a centered point source).
scalar: Whether the result should be ``ScalarField`` (if True) or
``VectorField`` (if False). Defaults to True.
"""
Expand All @@ -110,6 +113,7 @@ class ObjectivePointSource(nn.Module):
NA: Union[float, Callable[[PRNGKey], float]]
power: Union[float, Callable[[PRNGKey], float]] = 1.0
amplitude: Union[float, Array, Callable[[PRNGKey], Array]] = 1.0
offset: Union[Array, Tuple[float, float]] = (0.0, 0.0)
scalar: bool = True

@nn.compact
Expand All @@ -119,6 +123,7 @@ def __call__(self, z: float) -> Field:
NA = register(self, "NA")
power = register(self, "power")
amplitude = register(self, "amplitude")
offset = register(self, "offset")

return objective_point_source(
self.shape,
Expand All @@ -131,6 +136,7 @@ def __call__(self, z: float) -> Field:
NA,
power,
amplitude,
offset,
self.scalar,
)

Expand Down
14 changes: 9 additions & 5 deletions src/chromatix/functional/phase_masks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Tuple

import jax
import jax.numpy as jnp
from chex import Array, assert_rank

from ..field import Field
from ..utils.shapes import _broadcast_2d_to_spatial
from chromatix.field import Field
from chromatix.utils.shapes import _broadcast_2d_to_spatial

__all__ = ["phase_change", "wrap_phase", "spectrally_modulate_phase"]

Expand All @@ -30,6 +29,7 @@ def phase_change(field: Field, phase: Array, spectrally_modulate: bool = True) -
return field * jnp.exp(1j * phase)


@jax.custom_jvp
def wrap_phase(phase: Array, limits: Tuple[float, float] = (-jnp.pi, jnp.pi)) -> Array:
"""
Wraps values of ``phase`` to the range given by ``limits``.
Expand All @@ -40,7 +40,6 @@ def wrap_phase(phase: Array, limits: Tuple[float, float] = (-jnp.pi, jnp.pi)) ->
will be wrapped to.
"""
phase_min, phase_max = limits
assert phase_min < phase_max, "Lower limit needs to be smaller than upper limit."
phase = jnp.where(
phase < phase_min,
phase + 2 * jnp.pi * (1 + (phase_min - phase) // (2 * jnp.pi)),
Expand All @@ -54,6 +53,11 @@ def wrap_phase(phase: Array, limits: Tuple[float, float] = (-jnp.pi, jnp.pi)) ->
return phase


@wrap_phase.defjvp
def wrap_phase_jvp(primals: Tuple, tangents: Tuple) -> Tuple:
return wrap_phase(*primals), tangents[0]


def spectrally_modulate_phase(phase: Array, field: Field) -> Array:
"""Spectrally modulates a given ``phase`` for multiple wavelengths."""
central_wavelength = field.spectrum[..., 0, 0].squeeze()
Expand Down
3 changes: 1 addition & 2 deletions src/chromatix/functional/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,7 @@ def compute_transfer_propagator(
"""
kykx = _broadcast_1d_to_grid(kykx, field.ndim)
z = _broadcast_1d_to_innermost_batch(z, field.ndim)
L = jnp.sqrt(jnp.complex64(field.spectrum * z / n)) # lengthscale L
phase = -jnp.pi * jnp.abs(L) ** 2 * l2_sq_norm(field.k_grid - kykx)
phase = -jnp.pi * (field.spectrum / n) * z * l2_sq_norm(field.k_grid - kykx)
return jnp.fft.ifftshift(jnp.exp(1j * phase), axes=field.spatial_dims)


Expand Down
27 changes: 15 additions & 12 deletions src/chromatix/functional/sources.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Callable, Optional, Tuple, Union

import numpy as np
import jax.numpy as jnp
from chex import Array, assert_axis_dimension, assert_equal_shape

from ..field import Field, ScalarField, VectorField
from ..utils import l2_sq_norm
from ..utils.shapes import (
from chromatix.field import Field, ScalarField, VectorField
from chromatix.utils import l2_sq_norm
from chromatix.utils.shapes import (
_broadcast_1d_to_grid,
_broadcast_1d_to_innermost_batch,
_broadcast_1d_to_polarization,
Expand All @@ -31,6 +30,7 @@ def point_source(
amplitude: Union[float, Array] = 1.0,
pupil: Optional[Callable[[ScalarField], ScalarField]] = None,
scalar: bool = True,
epsilon: float = np.finfo(np.float32).eps,
) -> Field:
"""
Generates field due to point source a distance ``z`` away.
Expand All @@ -53,16 +53,15 @@ def point_source(
pupil: If provided, will be called on the field to apply a pupil.
scalar: Whether the result should be ``ScalarField`` (if True) or
``VectorField`` (if False). Defaults to True.
epsilon: Value added to denominators for numerical stability.
"""
create = ScalarField.create if scalar else VectorField.create
field = create(dx, spectrum, spectral_density, shape=shape)
z = _broadcast_1d_to_innermost_batch(z, field.ndim)
amplitude = _broadcast_1d_to_polarization(amplitude, field.ndim)
L = jnp.sqrt(
field.spectrum * jnp.abs(z) / n
) # the abs are to allow for negative z. Note that this does not lead to a conjugation for a point source
phase = jnp.pi * l2_sq_norm(field.grid) / L**2
u = amplitude * -1j / L**2 * jnp.exp(1j * phase)
L = jnp.sqrt(jnp.complex64(field.spectrum * z / n))
phase = jnp.pi * l2_sq_norm(field.grid) / (L**2 + epsilon)
u = amplitude * -1j / (L**2 + epsilon) * jnp.exp(1j * phase)
field = field.replace(u=u)
if pupil is not None:
field = pupil(field)
Expand All @@ -80,6 +79,7 @@ def objective_point_source(
NA: float,
power: float = 1.0,
amplitude: Union[float, Array] = 1.0,
offset: Union[Array, Tuple[float, float]] = (0.0, 0.0),
scalar: bool = True,
) -> Field:
"""
Expand All @@ -102,15 +102,18 @@ def objective_point_source(
amplitude: The amplitude of the electric field. For ``ScalarField`` this
doesnt do anything, but it is required for ``VectorField`` to set
the polarization.
offset: The offset of the point source in the plane. Should be an array
of shape `[2,]` in the format `[y, x]`.
scalar: Whether the result should be ``ScalarField`` (if True) or
``VectorField`` (if False). Defaults to True.
"""
create = ScalarField.create if scalar else VectorField.create
field = create(dx, spectrum, spectral_density, shape=shape)
z = _broadcast_1d_to_innermost_batch(z, field.ndim)
amplitude = _broadcast_1d_to_polarization(amplitude, field.ndim)
offset = _broadcast_1d_to_grid(offset, field.ndim)
L = jnp.sqrt(field.spectrum * f / n)
phase = -jnp.pi * (z / f) * l2_sq_norm(field.grid) / L**2
phase = -jnp.pi * (z / f) * l2_sq_norm(field.grid - offset) / L**2
u = amplitude * -1j / L**2 * jnp.exp(1j * phase)
field = field.replace(u=u)
D = 2 * f * NA / n
Expand Down Expand Up @@ -147,7 +150,7 @@ def plane_wave(
doesnt do anything, but it is required for ``VectorField`` to set
the polarization.
kykx: Defines the orientation of the plane wave. Should be an
array of shape `[2,]` in the format [ky, kx].
array of shape `[2,]` in the format `[ky, kx]`.
pupil: If provided, will be called on the field to apply a pupil.
scalar: Whether the result should be ``ScalarField`` (if True) or
``VectorField`` (if False). Defaults to True.
Expand Down
13 changes: 8 additions & 5 deletions src/chromatix/systems/microscopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class Microscope(nn.Module):
smoothly bring the edges of the PSF to 0. This helps to prevent
edge artifacts in the image if the PSF has edge artifacts. Defaults
to 0, in which case no tapering is applied.
fast_fft_shape: If `True`, Fourier convolutions will be computed at
potentially larger shapes to gain speed at the expense of increased
memory requirements. If you are running out of memory, try setting
this to `False`. Defaults to `True`.
"""

system_psf: Callable[[Microscope], Union[Field, Array]]
Expand All @@ -76,6 +80,7 @@ class Microscope(nn.Module):
spectral_density: Array
padding_ratio: float = 0
taper_width: float = 0
fast_fft_shape: bool = True

def __call__(self, sample: Array, *args: Any, **kwargs: Any) -> Array:
"""
Expand Down Expand Up @@ -151,13 +156,11 @@ def image(self, sample: Array, psf: Array, axes: Tuple[int, int] = (1, 2)) -> Ar
sample: The sample volume to image with of shape `(B... H W 1 1)`.
psf: The PSF intensity volume to image with of shape `(B... H W 1 1)`.
"""
image = fourier_convolution(psf, sample, axes=axes)
image = fourier_convolution(sample, psf, axes=axes, fast_fft_shape=self.fast_fft_shape)
# NOTE(dd): By this point, the image should already be at the same
# spacing as the sensor. Any resampling to the pixels of the sensor
# should already have happened to the PSF. The intent of passing
# the sensor spacing as the input spacing is to bypass any further
# resampling.
image = self.sensor(image, self.sensor.spacing)
# should already have happened to the PSF.
image = self.sensor(image, self.sensor.spacing, resample=False)
return image


Expand Down

0 comments on commit 736d9e7

Please sign in to comment.