diff --git a/py4DSTEM/process/phase/direct_ptychography.py b/py4DSTEM/process/phase/direct_ptychography.py index 2a8200c67..0cfeb064a 100644 --- a/py4DSTEM/process/phase/direct_ptychography.py +++ b/py4DSTEM/process/phase/direct_ptychography.py @@ -1461,7 +1461,6 @@ def _reconstruct_single_frequency( probe_conj, aperture, probe_kwargs, - trotter_sign, phase_compensation: bool = True, virtual_detector_masks: Sequence[np.ndarray] = None, xp=np, @@ -1513,35 +1512,19 @@ def _reconstruct_single_frequency( gamma_ind = gamma_abs > threshold normalization = gamma_abs[gamma_ind] - if trotter_sign == "+": - numerator = -G[gamma_ind].conj() * gamma[gamma_ind] - elif trotter_sign == "-": - numerator = G[gamma_ind] * gamma[gamma_ind].conj() - else: - raise ValueError() - - return (numerator / normalization).sum() + return (G[gamma_ind] * gamma[gamma_ind].conj() / normalization).sum() else: aperture_plus = aperture_plus > threshold aperture_minus = aperture_minus > threshold - if trotter_sign == "+": - aperture_solo = xp.logical_and( - xp.logical_and(aperture, aperture_plus), ~aperture_minus - ) - return G[aperture_solo].sum().conj() - elif trotter_sign == "-": - aperture_solo = xp.logical_and( - xp.logical_and(aperture, aperture_minus), ~aperture_plus - ) - return G[aperture_solo].sum() - else: - raise ValueError() + aperture_solo = xp.logical_and( + xp.logical_and(aperture, aperture_minus), ~aperture_plus + ) + return G[aperture_solo].sum() def reconstruct( self, - trotter_sign="-", phase_compensation=True, num_jobs=None, threads_per_job=None, @@ -1557,8 +1540,6 @@ def reconstruct( Parameters -------- - trotter_sign: str, optional - Sign of single-side trotter to use. One of '+','-'. phase_compensation: bool, optional If True, the measured phase is compensated using a complex virtual detector. Recommnended. num_jobs: int, optional @@ -1630,7 +1611,6 @@ def reconstruct( probe_conj, aperture, probe_kwargs, - trotter_sign=trotter_sign, phase_compensation=phase_compensation, virtual_detector_masks=virtual_detector_masks, xp=xp, @@ -1666,7 +1646,6 @@ def wrapper_function(**kwargs): probe_conj=probe_conj, aperture=aperture, probe_kwargs=probe_kwargs, - trotter_sign=trotter_sign, phase_compensation=phase_compensation, virtual_detector_masks=virtual_detector_masks, xp=xp, @@ -1707,7 +1686,6 @@ def _reconstruct_single_frequency( probe_conj, probe_normalization, probe_kwargs, - trotter_sign, virtual_detector_masks: Sequence[np.ndarray] = None, xp=np, ): @@ -1748,18 +1726,10 @@ def _reconstruct_single_frequency( d = probe_normalization[gamma_ind] normalization = d * xp.sqrt(xp.sum(normalization**2 / d)) - if trotter_sign == "+": - numerator = -G[gamma_ind].conj() * gamma[gamma_ind] - elif trotter_sign == "-": - numerator = G[gamma_ind] * gamma[gamma_ind].conj() - else: - raise ValueError() - - return (numerator / normalization).sum() + return (G[gamma_ind] * gamma[gamma_ind].conj() / normalization).sum() def reconstruct( self, - trotter_sign="-", num_jobs=None, threads_per_job=None, virtual_detector_masks: Sequence[np.ndarray] = None, @@ -1774,8 +1744,6 @@ def reconstruct( Parameters -------- - trotter_sign: str, optional - Sign of single-side trotter to use. One of '+','-'. num_jobs: int, optional Number of processes to use. Default is None, which spawns as many processes as CPUs on the system. @@ -1849,7 +1817,6 @@ def reconstruct( probe_conj, probe_normalization, probe_kwargs, - trotter_sign=trotter_sign, virtual_detector_masks=virtual_detector_masks, xp=xp, ) @@ -1884,7 +1851,6 @@ def wrapper_function(**kwargs): probe_conj=probe_conj, probe_normalization=probe_normalization, probe_kwargs=probe_kwargs, - trotter_sign=trotter_sign, virtual_detector_masks=virtual_detector_masks, xp=xp, ) @@ -1923,7 +1889,6 @@ def _reconstruct_single_frequency( probe, epsilon, probe_kwargs, - trotter_sign, xp=np, ): """ """ @@ -1931,35 +1896,25 @@ def _reconstruct_single_frequency( array_G = xp.asarray(intensities_FFT) array_H = xp.fft.ifft2(array_G) - if trotter_sign == "+": - Kx_Qx = Kx + Qx - Ky_Qy = Ky + Qy - elif trotter_sign == "-": - Kx_Qx = Kx - Qx - Ky_Qy = Ky - Qy - else: - raise ValueError() + Kx_Qx = Kx - Qx + Ky_Qy = Ky - Qy probe_shifted = ComplexProbe( **probe_kwargs, force_spatial_frequencies=(Kx_Qx, Ky_Qy), )._evaluate_ctf() - if trotter_sign == "+": - wdd_probe = xp.fft.ifft2(probe * probe_shifted.conj()) - else: - wdd_probe = xp.fft.ifft2(probe.conj() * probe_shifted) + wdd_probe = xp.fft.ifft2(probe.conj() * probe_shifted) wdd_probe_conj = wdd_probe.conj() array_D = wdd_probe_conj * array_H / (wdd_probe * wdd_probe_conj + epsilon) array_D_FFT = xp.fft.fft2(array_D) - return array_D_FFT[0, 0].conj() if trotter_sign == "+" else array_D_FFT[0, 0] + return array_D_FFT[0, 0] def reconstruct( self, relative_wiener_epsilon, - trotter_sign="-", num_jobs=None, threads_per_job=None, virtual_detector_masks: Sequence[np.ndarray] = None, @@ -1974,8 +1929,6 @@ def reconstruct( Parameters -------- - trotter_sign: str, optional - Sign of single-side trotter to use. One of '+','-'. worker_pool: WorkerPool If not None, reconstruction is dispatched to mpire WorkerPool instance. virtual_detector_masks: np.ndarray @@ -2042,7 +1995,6 @@ def reconstruct( self._fourier_probe, epsilon, probe_kwargs, - trotter_sign=trotter_sign, xp=xp, ) else: @@ -2075,7 +2027,6 @@ def wrapper_function(**kwargs): probe=self._fourier_probe, epsilon=epsilon, probe_kwargs=probe_kwargs, - trotter_sign=trotter_sign, xp=xp, )