Skip to content

Commit

Permalink
Multislice & object comparisons (#62)
Browse files Browse the repository at this point in the history
- Add support for even/odd split reconstructions
- Support object comparisons and Fourier Ring Correlation (FRC)
- Support loading and viewing multi-layer (multislice) objects including layer distance
- Relocate diffraction pattern loading to wizard in detectors view
- Simplify behavior of diffraction pattern crop widgets
- Add new mouse tools to image view: move, ruler, rectangle, line-cut
- Add option to view complex array intensity
- Add color legends for acyclic and cyclic quantities to image view
- Fix tike cost function plots
- Add memory usage monitor widget
- Add Fourier Zone Plate (FZP) presets for several Advanced Photon Source (APS) instruments
- Add space-deliminated scan position reader to support format used for NXSchool IC datasets
- Update interface to support PtychoNN v0.2
- Extract probe/object classes and refactor
  • Loading branch information
stevehenke authored Nov 17, 2023
1 parent ed84ed8 commit fa07fee
Show file tree
Hide file tree
Showing 118 changed files with 4,245 additions and 2,415 deletions.
21 changes: 18 additions & 3 deletions ptychodus/api/apparatus.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
from __future__ import annotations
from dataclasses import dataclass
from decimal import Decimal


@dataclass(frozen=True)
class FresnelZonePlate:
zonePlateDiameterInMeters: float
outermostZoneWidthInMeters: float
centralBeamstopDiameterInMeters: float

def focalLengthInMeters(self, centralWavelengthInMeters: float) -> float:
return self.zonePlateDiameterInMeters * self.outermostZoneWidthInMeters \
/ centralWavelengthInMeters


@dataclass(frozen=True)
class PixelGeometry:
widthInMeters: Decimal
heightInMeters: Decimal
widthInMeters: float
heightInMeters: float

@classmethod
def createNull(cls) -> PixelGeometry:
return cls(0., 0.)
5 changes: 2 additions & 3 deletions ptychodus/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from decimal import Decimal
from enum import Enum, auto
from pathlib import Path
from typing import overload, Any, Optional, TypeAlias, Union
Expand Down Expand Up @@ -83,12 +82,12 @@ class DiffractionMetadata:
numberOfPatternsPerArray: int
numberOfPatternsTotal: int
patternDataType: numpy.dtype[numpy.integer[Any]]
detectorDistanceInMeters: Optional[Decimal] = None
detectorDistanceInMeters: Optional[float] = None
detectorExtentInPixels: Optional[ImageExtent] = None
detectorPixelGeometry: Optional[PixelGeometry] = None
detectorBitDepth: Optional[int] = None
cropCenterInPixels: Optional[Array2D[int]] = None
probeEnergyInElectronVolts: Optional[Decimal] = None
probeEnergyInElectronVolts: Optional[float] = None
filePath: Optional[Path] = None

@classmethod
Expand Down
19 changes: 19 additions & 0 deletions ptychodus/api/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,31 @@ class Point2D(Generic[T]):
y: T


@dataclass(frozen=True)
class Line2D(Generic[T]):
begin: Point2D[T]
end: Point2D[T]

def lerp(self, alpha: T) -> Point2D[T]:
beta = 1 - alpha
x = beta * self.begin.x + alpha * self.end.x
y = beta * self.begin.y + alpha * self.end.y
return Point2D[T](x, y)


class Interval(Generic[T]):

def __init__(self, lower: T, upper: T) -> None:
self.lower: T = lower
self.upper: T = upper

@classmethod
def createProper(self, a: T, b: T) -> Interval[T]:
if b < a:
return Interval[T](b, a)
else:
return Interval[T](a, b)

@property
def isEmpty(self) -> bool:
return self.upper < self.lower
Expand Down
4 changes: 4 additions & 0 deletions ptychodus/api/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def __repr__(self) -> str:
class ScalarTransformation(ABC):
'''interface for real-valued transformations of a real array'''

@abstractmethod
def decorateText(self, text: str) -> str:
pass

@abstractmethod
def __call__(self, array: RealArrayType) -> RealArrayType:
'''returns the transformed input array'''
Expand Down
87 changes: 83 additions & 4 deletions ptychodus/api/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,89 @@
from .image import ImageExtent
from .scan import ScanPoint

ObjectArrayType: TypeAlias = numpy.typing.NDArray[numpy.complexfloating[Any, Any]]

# object point coordinates are conventionally in pixel units
ObjectPoint: TypeAlias = Point2D[float]
ObjectArrayType: TypeAlias = numpy.typing.NDArray[numpy.complexfloating[Any, Any]]


class Object:

def __init__(self, array: ObjectArrayType | None = None) -> None:
self._array = numpy.zeros((1, 0, 0), dtype=complex)
self._layerDistanceInMeters = [numpy.inf]
self._centerXInMeters = 0.
self._centerYInMeters = 0.

if array is not None:
self.setArray(array)

def copy(self) -> Object:
clone = Object()
clone._array = self._array.copy()
clone._layerDistanceInMeters = self._layerDistanceInMeters.copy()
clone._centerXInMeters = float(self._centerXInMeters)
clone._centerYInMeters = float(self._centerYInMeters)
return clone

def getArray(self) -> ObjectArrayType:
return self._array

def setArray(self, array: ObjectArrayType) -> None:
if not numpy.iscomplexobj(array):
raise TypeError('Object must be a complex-valued ndarray')

if array.ndim == 2:
self._array = array[numpy.newaxis, :, :]
elif array.ndim == 3:
self._array = array
else:
raise ValueError('Object must be 2- or 3-dimensional ndarray.')

numberOfAddedLayers = self._array.shape[-3] - len(self._layerDistanceInMeters)

if numberOfAddedLayers > 0:
self._layerDistanceInMeters[-1] = 0.
self._layerDistanceInMeters.extend([0.] * numberOfAddedLayers)
self._layerDistanceInMeters[-1] = numpy.inf

def getDataType(self) -> numpy.dtype:
return self._array.dtype

def getExtentInPixels(self) -> ImageExtent:
return ImageExtent(width=self._array.shape[-1], height=self._array.shape[-2])

def getSizeInBytes(self) -> int:
return self._array.nbytes

def getNumberOfLayers(self) -> int:
return self._array.shape[-3]

def getLayer(self, number: int) -> ObjectArrayType:
return self._array[number, :, :]

def getLayersFlattened(self) -> ObjectArrayType:
return numpy.prod(self._array, axis=-3)

def getLayerDistancesInMeters(self) -> Sequence[float]:
return self._layerDistanceInMeters

def getLayerDistanceInMeters(self, number: int) -> float:
return self._layerDistanceInMeters[number]

def setLayerDistanceInMeters(self, layer: int, distance: float) -> None:
if 0 <= layer and layer < self.getNumberOfLayers() - 1:
self._layerDistanceInMeters[layer] = distance

def getCenter(self) -> ScanPoint:
return ScanPoint(self._centerXInMeters, self._centerYInMeters)

def setCenter(self, center: ScanPoint) -> None:
self._centerXInMeters = center.x
self._centerYInMeters = center.y

def hasSameShape(self, other: Object) -> bool:
return (self._array.shape == other._array.shape
and self._layerDistanceInMeters == other._layerDistanceInMeters)


@dataclass(frozen=True)
Expand Down Expand Up @@ -159,14 +238,14 @@ def __call__(self, array: ObjectArrayType) -> ObjectArrayType:
class ObjectFileReader(ABC):

@abstractmethod
def read(self, filePath: Path) -> ObjectArrayType:
def read(self, filePath: Path) -> Object:
'''reads an object from file'''
pass


class ObjectFileWriter(ABC):

@abstractmethod
def write(self, filePath: Path, array: ObjectArrayType) -> None:
def write(self, filePath: Path, object_: Object) -> None:
'''writes an object to file'''
pass
96 changes: 95 additions & 1 deletion ptychodus/api/plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from __future__ import annotations
from dataclasses import dataclass
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, TypeAlias

import numpy
import numpy.typing
import scipy.fft

from .apparatus import PixelGeometry

ComplexArrayType: TypeAlias = numpy.typing.NDArray[numpy.complexfloating[Any, Any]]
IntegerArrayType: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]]


@dataclass(frozen=True)
Expand All @@ -27,3 +37,87 @@ class Plot2D:
@classmethod
def createNull(cls) -> Plot2D:
return cls(PlotAxis.createNull(), PlotAxis.createNull())


@dataclass(frozen=True)
class LineCut:
distanceInMeters: Sequence[float]
value: Sequence[float]
valueLabel: str


@dataclass(frozen=True)
class FourierRingCorrelation:
spatialFrequency_rm: Sequence[float]
correlation: Sequence[float]

@staticmethod
def _integrateRings(rings: IntegerArrayType, array: ComplexArrayType) -> ComplexArrayType:
total = numpy.zeros(numpy.max(rings) + 1, dtype=complex)

for index, value in zip(rings.flat, array.flat):
total[index] += value

return total

@classmethod
def calculate(cls, image1: ComplexArrayType, image2: ComplexArrayType,
pixelGeometry: PixelGeometry) -> FourierRingCorrelation:
'''
See: Joan Vila-Comamala, Ana Diaz, Manuel Guizar-Sicairos, Alexandre Mantion,
Cameron M. Kewish, Andreas Menzel, Oliver Bunk, and Christian David,
"Characterization of high-resolution diffractive X-ray optics by ptychographic
coherent diffractive imaging," Opt. Express 19, 21333-21344 (2011)
'''

if numpy.ndim(image1) != 2 or numpy.ndim(image2) != 2:
raise ValueError('Images must be 2D!')

if numpy.shape(image1) != numpy.shape(image2):
raise ValueError('Images must have same shape!')

# TODO subpixel image registration
# TODO remove phase offset and ramp
# TODO apply soft-edged mask
# TODO stats: SSNR, area under FRC curve, average SNR, etc.

x_rm = scipy.fft.fftfreq(image1.shape[-1], d=pixelGeometry.widthInMeters)
y_rm = scipy.fft.fftfreq(image1.shape[-2], d=pixelGeometry.heightInMeters)
radialBinSize_rm = max(x_rm[1], y_rm[1])

xx_rm, yy_rm = numpy.meshgrid(x_rm, y_rm)
rr_rm = numpy.hypot(xx_rm, yy_rm)

rings = numpy.divide(rr_rm, radialBinSize_rm).astype(int)
spatialFrequency_rm = numpy.arange(numpy.max(rings) + 1) * radialBinSize_rm

sf1 = scipy.fft.fft2(image1)
sf2 = scipy.fft.fft2(image2)

c11 = cls._integrateRings(rings, numpy.multiply(sf1, numpy.conj(sf1)))
c12 = cls._integrateRings(rings, numpy.multiply(sf1, numpy.conj(sf2)))
c22 = cls._integrateRings(rings, numpy.multiply(sf2, numpy.conj(sf2)))

correlation = numpy.absolute(c12) / numpy.sqrt(numpy.absolute(numpy.multiply(c11, c22)))

# TODO replace NaNs with interpolated values

rnyquist = numpy.min(image1.shape) // 2 + 1
return cls(spatialFrequency_rm[:rnyquist], correlation[:rnyquist])

def getResolutionInMeters(self, threshold: float) -> float:
# TODO threshold from bits
for freq_rm, frc in zip(self.spatialFrequency_rm, self.correlation):
if frc < threshold:
return 1. / freq_rm

return numpy.nan

def getPlot(self) -> Plot2D:
freqSeries = PlotSeries('freq', [1.e-9 * freq for freq in self.spatialFrequency_rm])
frcSeries = PlotSeries('frc', self.correlation)

return Plot2D(
axisX=PlotAxis('Spatial Frequency [1/nm]', [freqSeries]),
axisY=PlotAxis('Fourier Ring Correlation', [frcSeries]),
)
3 changes: 3 additions & 0 deletions ptychodus/api/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def __getitem__(self, name: str) -> PluginEntry[T]:

raise KeyError(f'Invalid plugin name \"{name}\"')

def __bool__(self) -> bool:
return bool(self._entryList)

def copy(self) -> PluginChooser[T]:
clone = PluginChooser[T]()
clone._entryList = self._entryList.copy()
Expand Down
Loading

0 comments on commit fa07fee

Please sign in to comment.