diff --git a/pyproject.toml b/pyproject.toml index 2fb4c9eb..7c856aee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "tqdm >= 4.41.1", "virtualbox >= 2.0.0", "pyserialem >= 0.3.2", + "diffpy.structure", ] [project.urls] diff --git a/src/instamatic/simulation/__init__.py b/src/instamatic/simulation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/instamatic/simulation/camera.py b/src/instamatic/simulation/camera.py new file mode 100644 index 00000000..29ccd2f8 --- /dev/null +++ b/src/instamatic/simulation/camera.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from typing import Tuple + +from numpy import ndarray + +from instamatic.camera.camera_base import CameraBase +from instamatic.simulation.stage import Stage + + +class CameraSimulation(CameraBase): + streamable = True + + def __init__(self, name: str = 'simulate'): + super().__init__(name) + + self.ready = False + + # TODO put parameters into config + self.stage = Stage() + self.mag = None + + def establish_connection(self): + pass + + def actually_establish_connection(self): + if self.ready: + return + import time + + time.sleep(2) + from instamatic.controller import get_instance + + ctrl = get_instance() + self.tem = ctrl.tem + + ctrl.stage.set(z=0, a=0, b=0) + print(self.tem.getStagePosition()) + print(self.stage.samples[0].x, self.stage.samples[0].y) + + self.ready = True + + def release_connection(self): + self.tem = None + self.ready = False + + def get_image(self, exposure: float = None, binsize: int = None, **kwargs) -> ndarray: + self.actually_establish_connection() + + if exposure is None: + exposure = self.default_exposure + if binsize is None: + binsize = self.default_binsize + + # TODO this has inconsistent units. Assume m, deg + pos = self.tem.getStagePosition() + if pos is not None and len(pos) == 5: + x, y, z, alpha, beta = pos + self.stage.set_position(x=x, y=y, z=z, alpha_tilt=alpha, beta_tilt=beta) + + mode = self.tem.getFunctionMode() + + # Get real-space extent + if mode == 'diff': + # TODO this has inconsistent units. Assume mm + self.camera_length = self.tem.getMagnification() + else: + mag = self.tem.getMagnification() + if isinstance(mag, (float, int)): + self.mag = mag + else: + print(mag, type(mag)) + if self.mag is None: + raise ValueError('Must start in image mode') + + # TODO consider beam shift, tilt ect. + x_min, x_max, y_min, y_max = self._mag_to_ranges(self.mag) + x_min += self.stage.x + x_max += self.stage.x + y_min += self.stage.y + y_max += self.stage.y + + # TODO I mean properly considering them, this has no regard for units ect + bx, by = self.tem.getBeamShift() + x_min += bx + x_max += bx + y_min += by + y_max += by + + shape_x, shape_y = self.get_camera_dimensions() + shape = (shape_x // binsize, shape_y // binsize) + + if mode == 'diff': + return self.stage.get_diffraction_pattern( + shape=shape, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max + ) + else: + return self.stage.get_image( + shape=shape, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max + ) + + def _mag_to_ranges(self, mag: float) -> Tuple[float, float, float, float]: + # assume 50x = 2mm full size + half_width = 50 * 1e6 / mag # 2mm/2 in nm is 1e6 + return -half_width, half_width, -half_width, half_width diff --git a/src/instamatic/simulation/crystal.py b/src/instamatic/simulation/crystal.py new file mode 100644 index 00000000..13875be3 --- /dev/null +++ b/src/instamatic/simulation/crystal.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +from typing import Type, TypeVar + +import numpy as np +from diffpy import structure as diffpy + +Crystal_T = TypeVar('Crystal_T', bound='Crystal') + + +class Crystal: + def __init__( + self, a: float, b: float, c: float, alpha: float, beta: float, gamma: float + ) -> None: + """Simulate a primitive crystal given the unit cell. No additional + symmetry is imposed. + + Standard orientation as defined in diffpy. + + Parameters + ---------- + a : float + Unit cell length a, in Å + b : float + Unit cell length b, in Å + c : float + Unit cell length c, in Å + alpha : float + Angle between b and c, in degrees + beta : float + Angle between a and c, in degrees + gamma : float + Angle between a and b, in degrees + """ + self.a = a + self.b = b + self.c = c + self.alpha = alpha + self.beta = beta + self.gamma = gamma + + self.lattice = diffpy.Lattice(self.a, self.b, self.c, self.alpha, self.beta, self.gamma) + self.structure = diffpy.Structure( + atoms=[diffpy.Atom(xyz=[0, 0, 0])], + lattice=self.lattice, + ) + + @property + def a_vec(self) -> np.ndarray: + return self.lattice.cartesian((1, 0, 0)) + + @property + def b_vec(self) -> np.ndarray: + return self.lattice.cartesian((0, 1, 0)) + + @property + def c_vec(self) -> np.ndarray: + return self.lattice.cartesian((0, 0, 1)) + + @property + def a_star_vec(self) -> np.ndarray: + return self.lattice.reciprocal().cartesian((1, 0, 0)) + + @property + def b_star_vec(self) -> np.ndarray: + return self.lattice.reciprocal().cartesian((0, 1, 0)) + + @property + def c_star_vec(self) -> np.ndarray: + return self.lattice.reciprocal().cartesian((0, 0, 1)) + + @classmethod + def default(cls: Type[Crystal_T]) -> Crystal_T: + return cls(1, 2, 3, 90, 100, 110) + + def real_space_lattice(self, d_max: float) -> np.ndarray: + """Get the real space lattice as a (n, 3) shape array. + + Parameters + ---------- + d_max: float + The maximum d-spacing + + Returns + ------- + np.ndarray + Shape (n, 3), lattice points + """ + max_h = int(d_max // self.a) + max_k = int(d_max // self.b) + max_l = int(d_max // self.c) + hkls = np.array( + [ + (h, k, l) + for h in range(-max_h, max_h + 1) # noqa: E741 + for k in range(-max_k, max_k + 1) # noqa: E741 + for l in range(-max_l, max_l + 1) # noqa: E741 + ] + ) + vecs = self.lattice.cartesian(hkls) + return vecs + + def reciprocal_space_lattice(self, d_min: float) -> np.ndarray: + """Get the reciprocal space lattice as a (n, 3) shape array for input + n. + + Parameters + ---------- + d_min: float + Minimum d-spacing included + + Returns + ------- + np.ndarray + Shape (n, 3), lattice points + """ + max_h = int(d_min // self.lattice.ar) + max_k = int(d_min // self.lattice.br) + max_l = int(d_min // self.lattice.cr) + hkls = np.array( + [ + (h, k, l) + for h in range(-max_h, max_h + 1) # noqa: E741 + for k in range(-max_k, max_k + 1) # noqa: E741 + for l in range(-max_l, max_l + 1) # noqa: E741 + ] + ) + vecs = self.lattice.reciprocal().cartesian(hkls) + return vecs + + def diffraction_pattern_mask( + self, + shape: tuple[int, int], + d_min: float, + rotation_matrix: np.ndarray, + wavelength: float, + excitation_error: float, + ) -> np.ndarray: + """Get a diffraction pattern with a given shape, up to a given + resolution, in a given orientation and wavelength. + + Parameters + ---------- + shape : tuple[int, int] + Output shape + d_min : float + Minimum d-spacing, in Å + rotation_matrix : np.ndarray + Orientation + wavelength : float + Wavelength of incident beam, in Å + excitation_error : float + Excitation error used for intensity calculation, in reciprocal Å + + Returns + ------- + np.ndarray + Diffraction pattern + """ + # TODO calibration + out = np.zeros(shape, dtype=bool) + + # TODO this depends on convergence angle + spot_radius = 3 # pixels + + vecs = self.reciprocal_space_lattice(d_min) + d = np.sum(vecs**2, axis=1) + vecs = vecs[d < d_min**2] + + k = 2 * np.pi / wavelength + k_vec = rotation_matrix @ np.array([0, 0, -k]) + + # Find intersect with Ewald's sphere + q_squared = np.sum((vecs - k_vec) ** 2, axis=1) + vecs = vecs[ + (q_squared > (k - excitation_error) ** 2) + & (q_squared < (k + excitation_error) ** 2) + ] + + # Project onto screen + vecs_xy = (rotation_matrix.T @ vecs.T).T[:, :-1] # ignoring curvature + + # Make image + for vec in vecs_xy: + x = int(vec[0] * d_min * shape[1] / 2) + shape[1] // 2 + y = int(vec[1] * d_min * shape[0] / 2) + shape[0] // 2 + min_x = max(0, x - spot_radius) + max_x = min(shape[1], x + spot_radius) + min_y = max(0, y - spot_radius) + max_y = min(shape[0], y + spot_radius) + out[min_y:max_y, min_x:max_x] = 1 + return out + + def __str__(self) -> str: + return f'{self.__class__.__name__}(a = {self.a}, b = {self.b}, c = {self.c}, alpha = {self.alpha}, beta = {self.beta}, gamma = {self.gamma})' + + +class CubicCrystal(Crystal): + def __init__(self, a: float) -> None: + super().__init__(a, a, a, 90, 90, 90) + + @classmethod + def default(cls: Type[Crystal_T]) -> Crystal_T: + return cls(1) + + +class HexagonalCrystal(Crystal): + def __init__(self, a: float, c: float) -> None: + super().__init__(a, a, c, 90, 90, 120) + + @classmethod + def default(cls: Type[Crystal_T]) -> Crystal_T: + return cls(1, 2) + + +class TrigonalCrystal(Crystal): + def __init__(self, a: float, alpha: float) -> None: + super().__init__(a, a, a, alpha, alpha, alpha) + + @classmethod + def default(cls: Type[Crystal_T]) -> Crystal_T: + return cls(1, 100) + + +class TetragonalCrystal(Crystal): + def __init__(self, a: float, c: float) -> None: + super().__init__(a, a, c, 90, 90, 90) + + @classmethod + def default(cls: Type[Crystal_T]) -> Crystal_T: + return cls(1, 2) + + +class OrthorhombicCrystal(Crystal): + def __init__(self, a: float, b: float, c: float) -> None: + super().__init__(a, b, c, 90, 90, 90) + + @classmethod + def default(cls: Type[Crystal_T]) -> Crystal_T: + return cls(1, 2, 3) + + +class MonoclinicCrystal(Crystal): + def __init__(self, a: float, b: float, c: float, beta: float) -> None: + super().__init__(a, b, c, 90, beta, 90) + + @classmethod + def default(cls: Type[Crystal_T]) -> Crystal_T: + return cls(1, 2, 3, 100) + + +class TriclinicCrystal(Crystal): + @classmethod + def default(cls: Type[Crystal_T]) -> Crystal_T: + return cls(1, 2, 3, 90, 100, 110) diff --git a/src/instamatic/simulation/grid.py b/src/instamatic/simulation/grid.py new file mode 100644 index 00000000..9d093e63 --- /dev/null +++ b/src/instamatic/simulation/grid.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import warnings + +import numpy as np + +from instamatic.simulation.warnings import NotImplementedWarning + +# TODO carbon lace + + +class Grid: + def __init__( + self, + diameter: float = 3.05, + mesh: int = 200, + pitch: float = 125, + hole_width: float = 90, + bar_width: float = 35, + rim_width: float = 0.225, + ): + """TEM grid. + + Parameters + ---------- + diameter : float, optional + [mm] Total diameter, by default 3.05 + mesh : int, optional + [lines/inch] Hole density, by default 200 + pitch : float, optional + [µm], by default 125 + hole_width : float, optional + [µm], by default 90 + bar_width : float, optional + [µm], by default 35 + rim_width : float, optional + [mm], by default 0.225 + """ + # TODO make mesh set the pitch, bar width and pitch set the hole width ect. + self.diameter_mm = diameter + self.radius_nm = 1e6 * diameter / 2 + self.mesh = mesh + self.pitch_um = pitch + self.hole_width_um = hole_width + self.bar_width_um = bar_width + self.rim_width_mm = rim_width + self.grid_width_um = self.bar_width_um + self.hole_width_um + self.grid_width_nm = 1e3 * self.grid_width_um + + def get_rim_filter(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: + r_nm = 1e6 * (self.diameter_mm / 2 - self.rim_width_mm) + return x**2 + y**2 >= r_nm**2 + + def get_hole_filter(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: + x = np.remainder(np.abs(x), self.grid_width_nm) + y = np.remainder(np.abs(y), self.grid_width_nm) + + # Assume no grid in center, i.e. middle of the bar width + half_bar_width_nm = 1e3 * self.bar_width_um / 2 + return ( + (x < half_bar_width_nm) + | (x > (self.grid_width_nm - half_bar_width_nm)) + | (y < half_bar_width_nm) + | (y > (self.grid_width_nm - half_bar_width_nm)) + ) + + def get_center_mark(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: + # TODO + warnings.warn('Center mark is not implemented yet', NotImplementedWarning) + return np.zeros(x.shape, dtype=bool) + + def array_from_coords(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: + """Get mask array for given coordinate arrays (output from + np.meshgrid). (x, y) = (0, 0) is in the center of the grid. + + Parameters + ---------- + x : np.ndarray + x-coordinates + y : np.ndarray + y-coordinates + + Returns + ------- + np.ndarray + Mask array, False where the grid is blocking + """ + rim_filter = self.get_rim_filter(x, y) + grid_filter = self.get_hole_filter(x, y) + + # TODO proper logic for this, + # as the mark includes a hole in the center which will be overridden by the grid filter + center_mark = self.get_center_mark(x, y) + + return rim_filter | grid_filter | center_mark + + def array( + self, + shape: tuple[int, int], + x_min: float, + x_max: float, + y_min: float, + y_max: float, + ) -> np.ndarray: + """Get mask array for given ranges. (x, y) = (0, 0) is in the center of + the grid. + + Parameters + ---------- + shape : tuple[int, int] + Output shape + x_min : float + [nm] Lower bound for x (left) + x_max : float + [nm] Upper bound for x (right) + y_min : float + [nm] Lower bound for y (bottom) + y_max : float + [nm] Upper bound for y (top) + + Returns + ------- + np.ndarray + Mask array, False where the grid is blocking + """ + x, y = np.meshgrid( + np.linspace(x_min, x_max, shape[1]), + np.linspace(y_min, y_max, shape[0]), + ) + return self.array_from_coords(x, y) diff --git a/src/instamatic/simulation/sample.py b/src/instamatic/simulation/sample.py new file mode 100644 index 00000000..f4786f58 --- /dev/null +++ b/src/instamatic/simulation/sample.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class Sample: + x: float + y: float + r: float + thickness: float # between 0 and 1 + euler_angle_phi_1: float + euler_angle_psi: float + euler_angle_phi_2: float + crystal_index: int = 0 # used for lookup in a list of crystals + + def __post_init__(self): + cp1 = np.cos(self.euler_angle_phi_1) + cp = np.cos(self.euler_angle_psi) + cp2 = np.cos(self.euler_angle_phi_2) + sp1 = np.sin(self.euler_angle_phi_1) + sp = np.sin(self.euler_angle_psi) + sp2 = np.sin(self.euler_angle_phi_2) + r1 = np.array([[cp1, sp1, 0], [-sp1, cp1, 0], [0, 0, 1]]) + r2 = np.array([[1, 0, 0], [0, cp, sp], [0, -sp, cp]]) + r3 = np.array([[cp2, sp2, 0], [-sp2, cp2, 0], [0, 0, 1]]) + self.rotation_matrix = r1 @ r2 @ r3 + + def pixel_contains_crystal(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: + """Given arrays of x- and y- coordinates in the lab frame, calculate + whether the crystal overlaps with these positions. + + Parameters + ---------- + x : np.ndarray + x coordinates + y : np.ndarray + y coordinates + + Returns + ------- + np.ndarray + Same shape as inputs, dtype bool + """ + return (x - self.x) ** 2 + (y - self.y) ** 2 < self.r**2 + + def range_might_contain_crystal( + self, + x_min: float, + x_max: float, + y_min: float, + y_max: float, + ) -> bool: + """Simple estimate of whether a range contains the crystal. This check + is fast but inaccurate. False positives are possible, false negatives + are impossible. + + Parameters + ---------- + x_min : float + Lower bound for x + x_max : float + Upper bound for x + y_min : float + Lower bound for y + y_max : float + Upper bound for y + + Returns + ------- + bool + True if range contains crystal + """ + # TODO ensure the docstring is true, regarding false negatives. + # TODO improve estimate? + # TODO handle this correctly when stage is rotated... + in_x = x_min <= self.x + self.r and self.x - self.r <= x_max + in_y = y_min <= self.y + self.r and self.y - self.r <= y_max + return in_x and in_y diff --git a/src/instamatic/simulation/stage.py b/src/instamatic/simulation/stage.py new file mode 100644 index 00000000..a5b57036 --- /dev/null +++ b/src/instamatic/simulation/stage.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import warnings + +import numpy as np +from scipy.spatial.transform import Rotation + +from instamatic.simulation.crystal import Crystal +from instamatic.simulation.grid import Grid +from instamatic.simulation.sample import Sample +from instamatic.simulation.warnings import NotImplementedWarning + + +class Stage: + def __init__( + self, + num_crystals: int = 100_000, + min_crystal_size: float = 100, + max_crystal_size: float = 1000, + random_seed: int = 100, + ) -> None: + """Handle many samples on a grid. + + Parameters + ---------- + num_crystals : int, optional + Number of crystals to disperse on the grid, by default 100_000 + min_crystal_size : float, optional + Minimum radius of the crystals, in nm, by default 100 + max_crystal_size : float, optional + Maximum radius of the crystals, in nm, by default 1000 + random_seed : int, optional + Seed for random number generation, by default 100 + """ + # TODO make this settable + self.x = 0 + self.y = 0 + self.z = 0 + self.alpha_tilt = 0 + self.beta_tilt = 0 + self.in_plane_rotation = 0 # TODO change this with focus/magnification + self.rotation_matrix = np.eye(3) + self.origin = np.array([0, 0, 0]) + + # TODO parameters + self.grid = Grid() + + self.rng = np.random.Generator(np.random.PCG64(random_seed)) + + # TODO parameters + # TODO multiple phases + # TODO amorphous phase + self.crystal = Crystal(*self.rng.uniform(5, 25, 3), *self.rng.uniform(80, 110, 3)) + + self.samples = [ + Sample( + x=self.rng.uniform(-self.grid.radius_nm, self.grid.radius_nm), + y=self.rng.uniform(-self.grid.radius_nm, self.grid.radius_nm), + r=self.rng.uniform(min_crystal_size, max_crystal_size), + thickness=self.rng.uniform(0, 1), + euler_angle_phi_1=self.rng.uniform(0, 2 * np.pi), + euler_angle_psi=self.rng.uniform(0, np.pi), + euler_angle_phi_2=self.rng.uniform(0, 2 * np.pi), + ) + for _ in range(num_crystals) + ] + + def set_position( + self, + x: float = None, + y: float = None, + z: float = None, + alpha_tilt: float = None, + beta_tilt: float = None, + ): + if x is not None: + self.x = x + if y is not None: + self.y = y + if z is not None: + self.z = z + if alpha_tilt is not None: + warnings.warn( + 'Tilting is not fully implemented yet', + NotImplementedWarning, + stacklevel=2, + ) + self.alpha_tilt = alpha_tilt + if beta_tilt is not None: + warnings.warn( + 'Tilting is not fully implemented yet', + NotImplementedWarning, + stacklevel=2, + ) + self.beta_tilt = beta_tilt + + # TODO define orientation. Is this matrix multiplied with lab coordinates to get sample coordinates? + self.rotation_matrix = Rotation.from_euler( + 'ZXY', + [self.in_plane_rotation, self.alpha_tilt, self.beta_tilt], + degrees=True, + ).as_matrix() + + def image_extent_to_sample_coordinates( + self, + shape: tuple[int, int], + x_min: float, + x_max: float, + y_min: float, + y_max: float, + ) -> tuple[np.ndarray, np.ndarray]: + """Get arrays of grid positions with a given shape and extent in lab + coordinates. + + Parameters + ---------- + shape : tuple[int, int] + Output shape + x_min : float + Lower bound of x + x_max : float + Upper bound of x + y_min : float + Lower bound of y + y_max : float + Upper bound of y + + Returns + ------- + tuple[np.ndarray, np.ndarray] + x, y. 2D arrays of floats + """ + if self.alpha_tilt != 0 or self.beta_tilt != 0: + warnings.warn( + 'Tilting is not fully implemented yet', NotImplementedWarning, stacklevel=2 + ) + # https://en.wikipedia.org/wiki/Line%E2%80%93plane_intersection + n = self.rotation_matrix @ np.array([0, 0, 1]) + p0 = self.origin + l = np.array([0, 0, 1]) # noqa: E741 + l0 = np.array( + [ + p.flatten() + for p in np.meshgrid( + np.linspace(x_min, x_max, shape[1]), + np.linspace(y_min, y_max, shape[0]), + [0], + ) + ] + ) + + p = l0 + np.array([0, 0, 1])[:, np.newaxis] * np.dot(-l0.T + p0, n) / np.dot(l, n) + + x, y, z = self.rotation_matrix.T @ p + x = x.reshape(shape) + y = y.reshape(shape) + return x, y + + def get_image( + self, + shape: tuple[int, int], + x_min: float, + x_max: float, + y_min: float, + y_max: float, + ) -> np.ndarray: + """Get image array for given ranges. (x, y) = (0, 0) is in the center + of the grid. + + Parameters + ---------- + shape : tuple[int, int] + Output shape + x_min : float + [nm] Lower bound for x (left) + x_max : float + [nm] Upper bound for x (right) + y_min : float + [nm] Lower bound for y (bottom) + y_max : float + [nm] Upper bound for y (top) + + Returns + ------- + np.ndarray + Image + """ + x, y = self.image_extent_to_sample_coordinates( + shape=shape, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max + ) + + grid_mask = self.grid.array_from_coords(x, y) + + sample_data = np.ones(shape, dtype=int) * 1000 + for ind, sample in enumerate(self.samples): + if not sample.range_might_contain_crystal( + x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max + ): + continue + # TODO better logic here + sample_data[sample.pixel_contains_crystal(x, y)] = 1000 * (1 - sample.thickness) + + sample_data[grid_mask] = 0 + + return sample_data + + def get_diffraction_pattern( + self, + shape: tuple[int, int], + x_min: float, + x_max: float, + y_min: float, + y_max: float, + camera_length: float = 150, + ) -> np.ndarray: + """Get diffraction pattern array for given ranges. (x, y) = (0, 0) is + in the center of the grid. + + Parameters + ---------- + shape : tuple[int, int] + Output shape + x_min : float + [nm] Lower bound for x (left) + x_max : float + [nm] Upper bound for x (right) + y_min : float + [nm] Lower bound for y (bottom) + y_max : float + [nm] Upper bound for y (top) + camera_length : float + [cm] Camera length, for calibration + + Returns + ------- + np.ndarray + diffraction pattern + """ + x, y = self.image_extent_to_sample_coordinates( + shape=shape, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max + ) + d_min = 1.0 + + grid_mask = self.grid.array_from_coords(x, y) + + if np.all(grid_mask): + # no transmission + return np.zeros(shape, dtype=int) + + reflections = np.zeros(shape, dtype=bool) + + # Direct beam + reflections[ + shape[0] // 2 - 4 : shape[0] // 2 + 4, shape[1] // 2 - 4 : shape[1] // 2 + 4 + ] = 1 + + for sample in self.samples: + if not sample.range_might_contain_crystal( + x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max + ): + continue + pos = sample.pixel_contains_crystal(x, y) + if np.all(grid_mask[pos]): + # Crystal is completely on the grid + continue + + reflections |= self.crystal.diffraction_pattern_mask( + shape, + d_min=d_min, + rotation_matrix=self.rotation_matrix @ sample.rotation_matrix, + wavelength=0.02, + excitation_error=0.01, + ) + # TODO diffraction shift + + # TODO noise + + # Simple scaling + # TODO improve, proper form factors maybe + # TODO camera length + kx, ky = np.meshgrid( + np.linspace(-1 / d_min, 1 / d_min, shape[1]), + np.linspace(-1 / d_min, 1 / d_min, shape[0]), + ) + k_squared = kx**2 + ky**2 + scale = 1 / (3 * k_squared + 1) + + scale[~reflections] = 0 + + # Convert to int array + scale = (scale * 0x8000).astype(int) + + return scale diff --git a/src/instamatic/simulation/warnings.py b/src/instamatic/simulation/warnings.py new file mode 100644 index 00000000..dfca684b --- /dev/null +++ b/src/instamatic/simulation/warnings.py @@ -0,0 +1,5 @@ +from __future__ import annotations + + +class NotImplementedWarning(UserWarning): + pass diff --git a/tests/test_simulation/__init__.py b/tests/test_simulation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_simulation/test_crystal.py b/tests/test_simulation/test_crystal.py new file mode 100644 index 00000000..8a1ad4ea --- /dev/null +++ b/tests/test_simulation/test_crystal.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import Type + +import pytest + +from instamatic.simulation.crystal import ( + Crystal, + CubicCrystal, + HexagonalCrystal, + MonoclinicCrystal, + OrthorhombicCrystal, + TetragonalCrystal, + TriclinicCrystal, + TrigonalCrystal, +) + + +def test_crystal_init(): + Crystal(1, 1, 1, 1, 1, 1) + + +@pytest.mark.parametrize( + 'crystal', + [ + Crystal, + CubicCrystal, + HexagonalCrystal, + TrigonalCrystal, + TetragonalCrystal, + OrthorhombicCrystal, + MonoclinicCrystal, + TriclinicCrystal, + ], +) +def test_crystal_default(crystal: Type[Crystal]): + c = crystal.default() + assert isinstance(c, Crystal) + + +def test_get_lattice_cubic(): + c = CubicCrystal.default() + lat = c.real_space_lattice(1) + assert pytest.approx(lat) == [ + (-1, -1, -1), + (-1, -1, 0), + (-1, -1, 1), + (-1, 0, -1), + (-1, 0, 0), + (-1, 0, 1), + (-1, 1, -1), + (-1, 1, 0), + (-1, 1, 1), + (0, -1, -1), + (0, -1, 0), + (0, -1, 1), + (0, 0, -1), + (0, 0, 0), + (0, 0, 1), + (0, 1, -1), + (0, 1, 0), + (0, 1, 1), + (1, -1, -1), + (1, -1, 0), + (1, -1, 1), + (1, 0, -1), + (1, 0, 0), + (1, 0, 1), + (1, 1, -1), + (1, 1, 0), + (1, 1, 1), + ] + lat = c.reciprocal_space_lattice(1) + assert pytest.approx(lat) == [ + (-1, -1, -1), + (-1, -1, 0), + (-1, -1, 1), + (-1, 0, -1), + (-1, 0, 0), + (-1, 0, 1), + (-1, 1, -1), + (-1, 1, 0), + (-1, 1, 1), + (0, -1, -1), + (0, -1, 0), + (0, -1, 1), + (0, 0, -1), + (0, 0, 0), + (0, 0, 1), + (0, 1, -1), + (0, 1, 0), + (0, 1, 1), + (1, -1, -1), + (1, -1, 0), + (1, -1, 1), + (1, 0, -1), + (1, 0, 0), + (1, 0, 1), + (1, 1, -1), + (1, 1, 0), + (1, 1, 1), + ] diff --git a/tests/test_simulation/test_grid.py b/tests/test_simulation/test_grid.py new file mode 100644 index 00000000..f1116517 --- /dev/null +++ b/tests/test_simulation/test_grid.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from instamatic.simulation.grid import Grid + + +def test_init(): + Grid() + + +def test_get_array(): + g = Grid( + hole_width=9, + bar_width=1, + ) + arr = g.array( + shape=(40, 40), + x_min=g.grid_width_nm, + x_max=3 * g.grid_width_nm, + y_min=-4 * g.grid_width_nm, + y_max=-2 * g.grid_width_nm, + ) + + # Grid + assert np.all(arr[:, 0]) + assert np.all(arr[:, -1]) + assert np.all(arr[0, :]) + assert np.all(arr[-1, :]) + assert np.all(arr[:, 19]) + assert np.all(arr[:, 20]) + assert np.all(arr[19, :]) + assert np.all(arr[20, :]) + + # Holes + assert np.sum(arr[1:19, 1:19]) == 0 + assert np.sum(arr[1:19, 21:-1]) == 0 + assert np.sum(arr[21:-1, 1:19]) == 0 + assert np.sum(arr[21:-1, 21:-1]) == 0 + + +@pytest.mark.xfail(reason='TODO') +def test_get_array_including_center(): + assert False, 'TODO' + + +@pytest.mark.xfail(reason='TODO') +def test_get_array_including_rim(): + assert False, 'TODO' diff --git a/tests/test_simulation/test_sample.py b/tests/test_simulation/test_sample.py new file mode 100644 index 00000000..f01022dd --- /dev/null +++ b/tests/test_simulation/test_sample.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import pytest + +from instamatic.simulation.sample import Sample + + +def test_init(): + s = Sample(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + assert isinstance(s, Sample) + + +def test_range_might_contain_crystal(): + s = Sample(0, 0, 1, 0, 0, 0, 0) + assert s.range_might_contain_crystal(-1, 1, -1, 1) + assert not s.range_might_contain_crystal(9, 10, 9, 10) + assert not s.range_might_contain_crystal(-10, -9, -10, -9) + assert not s.range_might_contain_crystal(-10, -9, -10, 10) + assert not s.range_might_contain_crystal(-10, 10, -10, -9) + assert s.range_might_contain_crystal(0, 1, 0, 1) + # TODO expand + + +@pytest.mark.xfail(reason='TODO') +def test_pixel_contains_crystal(): + assert False, 'TODO' + + +def test_range_might_contain_crystal_false_positive(): + s = Sample(0, 0, 1, 0, 0, 0, 0) + x_min = 0.9 + x_max = 1 + y_min = 0.9 + y_max = 1 + assert s.range_might_contain_crystal(x_min, x_max, y_min, y_max) + assert not s.pixel_contains_crystal(x_min, y_min) + # TODO expand + + +@pytest.mark.xfail(reason='Need to figure out how this can be done') +def test_range_might_contain_crystal_false_negative(): + assert False, 'TODO' diff --git a/tests/test_simulation/test_stage.py b/tests/test_simulation/test_stage.py new file mode 100644 index 00000000..bd6e4c00 --- /dev/null +++ b/tests/test_simulation/test_stage.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import pytest + +from instamatic.simulation.stage import Stage +from instamatic.simulation.warnings import NotImplementedWarning + + +def test_init_default(): + s = Stage() + assert isinstance(s, Stage) + + +def test_set_position(): + s = Stage() + s.set_position(x=10) + assert s.x == 10 + s.set_position(z=10) + assert s.x == 10 + assert s.z == 10 + with pytest.warns(NotImplementedWarning): + s.set_position(alpha_tilt=1) + with pytest.warns(NotImplementedWarning): + s.set_position(beta_tilt=1) + with pytest.warns(NotImplementedWarning): + s.set_position(x=1, y=1, z=1, alpha_tilt=1, beta_tilt=1) + + +@pytest.mark.xfail(reason='TODO') +def test_tilt(): + # Somehow check that the projected coordinates using stage.image_extent_to_sample_coordinates are correct + assert False, 'TODO' + + +@pytest.mark.xfail(reason='TODO') +def test_image_rotation(): + # Image rotates with focus ect. + assert False, 'TODO'