Skip to content

Commit

Permalink
more robust vacuum probe handling
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Oct 12, 2024
1 parent db3dbb9 commit 025d950
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
14 changes: 12 additions & 2 deletions py4DSTEM/process/phase/direct_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,17 @@ def preprocess(
if plot_overlap_trotters:

f = fx**2 + fy**2
q_probe = self._reciprocal_sampling[0] * 20 / self.angular_sampling[0]

if self._semiangle_cutoff == np.inf:
alpha_probe = (
xp.sqrt(xp.sum(xp.abs(self._fourier_probe_initial)) / np.pi)
* self.angular_sampling[0]
)
else:
alpha_probe = self._semiangle_cutoff
q_probe = (
self._reciprocal_sampling[0] * alpha_probe / self.angular_sampling[0]
)

bf_inds = f[self._trotter_inds[0], self._trotter_inds[1]] < q_probe
low_ind_x = self._trotter_inds[0][bf_inds][0]
Expand Down Expand Up @@ -1167,7 +1177,7 @@ def unwrap_trotter_phase(complex_data, mask):
) / (m + 1) - aberrations_normalization[aperture_plus_solo]
aberrations_basis_minus[
ind, :n_minus, a0
] = aberrations_normalization[aperture_plus_solo] - (
] = aberrations_normalization[aperture_minus_solo] - (
alpha_minus[aperture_minus_solo] ** (m + 1)
* xp.sin(n * phi_minus[aperture_minus_solo])
/ (m + 1)
Expand Down
32 changes: 24 additions & 8 deletions py4DSTEM/process/phase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_array_module(*args):
return np


from py4DSTEM.process.utils import get_CoM
from py4DSTEM.process.utils import get_CoM, get_shifted_ar
from py4DSTEM.process.utils.cross_correlate import align_and_shift_images
from py4DSTEM.process.utils.utils import electron_wavelength_angstrom

Expand Down Expand Up @@ -131,6 +131,7 @@ def __init__(
self._wavelength = electron_wavelength_angstrom(energy)
self._gpts = gpts
self._sampling = sampling
self._device = device

self._parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))

Expand Down Expand Up @@ -171,15 +172,19 @@ def evaluate_aperture(
semiangle_cutoff = self._semiangle_cutoff / 1000

if self._vacuum_probe_intensity is not None:
vacuum_probe_intensity = xp.asarray(
self._vacuum_probe_intensity, dtype=xp.float32
)
vacuum_probe_amplitude = xp.sqrt(xp.maximum(vacuum_probe_intensity, 0))
if self._force_spatial_frequencies is not None:
origin = np.unravel_index(alpha.argmin(), self._gpts)
vacuum_probe_amplitude = xp.roll(
vacuum_probe_amplitude, -np.array(origin), axis=(0, 1)
vacuum_probe_intensity = get_shifted_ar(
xp.asarray(self._vacuum_probe_intensity, dtype=xp.float32),
self._origin[0],
self._origin[1],
bilinear=True,
device=self._device,
)
else:
vacuum_probe_intensity = xp.asarray(
self._vacuum_probe_intensity, dtype=xp.float32
)
vacuum_probe_amplitude = xp.sqrt(xp.maximum(vacuum_probe_intensity, 0))
return vacuum_probe_amplitude

if self._semiangle_cutoff == xp.inf:
Expand Down Expand Up @@ -423,6 +428,17 @@ def get_spatial_frequencies(self):
kx, ky = self._force_spatial_frequencies
kx = xp.asarray(kx).astype(xp.float32)
ky = xp.asarray(ky).astype(xp.float32)

def find_zero_crossing(x):
n = x.shape[0]
y0, y1 = np.argsort(np.abs(x))[:2]
x0, x1 = x[y0], x[y1]
y = (y0 * x1 - y1 * x0) / (x1 - x0)
dy = np.mod(y + n / 2, n) - n / 2
return dy

self._origin = tuple(find_zero_crossing(k) for k in [kx, ky])

return kx, ky

def polar_coordinates(self, x, y):
Expand Down

0 comments on commit 025d950

Please sign in to comment.