diff --git a/src/chromatix/functional/propagation.py b/src/chromatix/functional/propagation.py index b5c24d1..88d4a2f 100644 --- a/src/chromatix/functional/propagation.py +++ b/src/chromatix/functional/propagation.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp from ..field import Field from einops import rearrange @@ -73,8 +74,12 @@ def transfer_propagate( # assert N_pad % 2 == 0, "Padding should be even." # Calculating propagator L = jnp.sqrt(jnp.complex64(field.spectrum * z / n)) # lengthscale L - f = jnp.fft.fftfreq(field.shape[1] + N_pad, d=field.dx.squeeze()) - fx, fy = rearrange(f, "h -> 1 h 1 1"), rearrange(f, "w -> 1 1 w 1") + # TODO(dd): This calculation could probably go into Field + f = [] + for d in range(field.dx.size): + f.append(jnp.fft.fftfreq(field.shape[1] + N_pad, d=field.dx[..., d].squeeze())) + f = jnp.stack(f, axis=-1) + fx, fy = rearrange(f, "h c -> 1 h 1 c"), rearrange(f, "w c -> 1 1 w c") phase = -jnp.pi * L**2 * (fx**2 + fy**2) # Propagating field @@ -93,9 +98,6 @@ def transfer_propagate( return field -# Exact transfer method - - def exact_propagate( field: Field, z: float, @@ -117,8 +119,11 @@ def exact_propagate( ConcretizationError will arise when traced!). """ # Calculating propagator - f = jnp.fft.fftfreq(field.shape[1] + N_pad, d=field.dx.squeeze()) - fx, fy = rearrange(f, "h -> 1 h 1 1"), rearrange(f, "w -> 1 1 w 1") + f = [] + for d in range(field.dx.size): + f.append(jnp.fft.fftfreq(field.shape[1] + N_pad, d=field.dx[..., d].squeeze())) + f = jnp.stack(f, axis=-1) + fx, fy = rearrange(f, "h c -> 1 h 1 c"), rearrange(f, "w c -> 1 1 w c") kernel = 1 - (field.spectrum / n) ** 2 * (fx**2 + fy**2) kernel = jnp.maximum(kernel, 0.0) # removing evanescent waves phase = 2 * jnp.pi * (z * n / field.spectrum) * jnp.sqrt(kernel)