Skip to content

Commit

Permalink
remove trotter_sign -- correct is negative
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Oct 20, 2024
1 parent 60f9e94 commit b697818
Showing 1 changed file with 10 additions and 59 deletions.
69 changes: 10 additions & 59 deletions py4DSTEM/process/phase/direct_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,6 @@ def _reconstruct_single_frequency(
probe_conj,
aperture,
probe_kwargs,
trotter_sign,
phase_compensation: bool = True,
virtual_detector_masks: Sequence[np.ndarray] = None,
xp=np,
Expand Down Expand Up @@ -1513,35 +1512,19 @@ def _reconstruct_single_frequency(
gamma_ind = gamma_abs > threshold
normalization = gamma_abs[gamma_ind]

if trotter_sign == "+":
numerator = -G[gamma_ind].conj() * gamma[gamma_ind]
elif trotter_sign == "-":
numerator = G[gamma_ind] * gamma[gamma_ind].conj()
else:
raise ValueError()

return (numerator / normalization).sum()
return (G[gamma_ind] * gamma[gamma_ind].conj() / normalization).sum()

else:
aperture_plus = aperture_plus > threshold
aperture_minus = aperture_minus > threshold

if trotter_sign == "+":
aperture_solo = xp.logical_and(
xp.logical_and(aperture, aperture_plus), ~aperture_minus
)
return G[aperture_solo].sum().conj()
elif trotter_sign == "-":
aperture_solo = xp.logical_and(
xp.logical_and(aperture, aperture_minus), ~aperture_plus
)
return G[aperture_solo].sum()
else:
raise ValueError()
aperture_solo = xp.logical_and(
xp.logical_and(aperture, aperture_minus), ~aperture_plus
)
return G[aperture_solo].sum()

def reconstruct(
self,
trotter_sign="-",
phase_compensation=True,
num_jobs=None,
threads_per_job=None,
Expand All @@ -1557,8 +1540,6 @@ def reconstruct(
Parameters
--------
trotter_sign: str, optional
Sign of single-side trotter to use. One of '+','-'.
phase_compensation: bool, optional
If True, the measured phase is compensated using a complex virtual detector. Recommnended.
num_jobs: int, optional
Expand Down Expand Up @@ -1630,7 +1611,6 @@ def reconstruct(
probe_conj,
aperture,
probe_kwargs,
trotter_sign=trotter_sign,
phase_compensation=phase_compensation,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
Expand Down Expand Up @@ -1666,7 +1646,6 @@ def wrapper_function(**kwargs):
probe_conj=probe_conj,
aperture=aperture,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
phase_compensation=phase_compensation,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
Expand Down Expand Up @@ -1707,7 +1686,6 @@ def _reconstruct_single_frequency(
probe_conj,
probe_normalization,
probe_kwargs,
trotter_sign,
virtual_detector_masks: Sequence[np.ndarray] = None,
xp=np,
):
Expand Down Expand Up @@ -1748,18 +1726,10 @@ def _reconstruct_single_frequency(
d = probe_normalization[gamma_ind]
normalization = d * xp.sqrt(xp.sum(normalization**2 / d))

if trotter_sign == "+":
numerator = -G[gamma_ind].conj() * gamma[gamma_ind]
elif trotter_sign == "-":
numerator = G[gamma_ind] * gamma[gamma_ind].conj()
else:
raise ValueError()

return (numerator / normalization).sum()
return (G[gamma_ind] * gamma[gamma_ind].conj() / normalization).sum()

def reconstruct(
self,
trotter_sign="-",
num_jobs=None,
threads_per_job=None,
virtual_detector_masks: Sequence[np.ndarray] = None,
Expand All @@ -1774,8 +1744,6 @@ def reconstruct(
Parameters
--------
trotter_sign: str, optional
Sign of single-side trotter to use. One of '+','-'.
num_jobs: int, optional
Number of processes to use. Default is None, which spawns as many processes as CPUs on
the system.
Expand Down Expand Up @@ -1849,7 +1817,6 @@ def reconstruct(
probe_conj,
probe_normalization,
probe_kwargs,
trotter_sign=trotter_sign,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
)
Expand Down Expand Up @@ -1884,7 +1851,6 @@ def wrapper_function(**kwargs):
probe_conj=probe_conj,
probe_normalization=probe_normalization,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
)
Expand Down Expand Up @@ -1923,43 +1889,32 @@ def _reconstruct_single_frequency(
probe,
epsilon,
probe_kwargs,
trotter_sign,
xp=np,
):
""" """

array_G = xp.asarray(intensities_FFT)
array_H = xp.fft.ifft2(array_G)

if trotter_sign == "+":
Kx_Qx = Kx + Qx
Ky_Qy = Ky + Qy
elif trotter_sign == "-":
Kx_Qx = Kx - Qx
Ky_Qy = Ky - Qy
else:
raise ValueError()
Kx_Qx = Kx - Qx
Ky_Qy = Ky - Qy

probe_shifted = ComplexProbe(
**probe_kwargs,
force_spatial_frequencies=(Kx_Qx, Ky_Qy),
)._evaluate_ctf()

if trotter_sign == "+":
wdd_probe = xp.fft.ifft2(probe * probe_shifted.conj())
else:
wdd_probe = xp.fft.ifft2(probe.conj() * probe_shifted)
wdd_probe = xp.fft.ifft2(probe.conj() * probe_shifted)
wdd_probe_conj = wdd_probe.conj()

array_D = wdd_probe_conj * array_H / (wdd_probe * wdd_probe_conj + epsilon)
array_D_FFT = xp.fft.fft2(array_D)

return array_D_FFT[0, 0].conj() if trotter_sign == "+" else array_D_FFT[0, 0]
return array_D_FFT[0, 0]

def reconstruct(
self,
relative_wiener_epsilon,
trotter_sign="-",
num_jobs=None,
threads_per_job=None,
virtual_detector_masks: Sequence[np.ndarray] = None,
Expand All @@ -1974,8 +1929,6 @@ def reconstruct(
Parameters
--------
trotter_sign: str, optional
Sign of single-side trotter to use. One of '+','-'.
worker_pool: WorkerPool
If not None, reconstruction is dispatched to mpire WorkerPool instance.
virtual_detector_masks: np.ndarray
Expand Down Expand Up @@ -2042,7 +1995,6 @@ def reconstruct(
self._fourier_probe,
epsilon,
probe_kwargs,
trotter_sign=trotter_sign,
xp=xp,
)
else:
Expand Down Expand Up @@ -2075,7 +2027,6 @@ def wrapper_function(**kwargs):
probe=self._fourier_probe,
epsilon=epsilon,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
xp=xp,
)

Expand Down

0 comments on commit b697818

Please sign in to comment.