diff --git a/src/cryo_sbi/wpa_simulator/ctf.py b/src/cryo_sbi/wpa_simulator/ctf.py index d7fa173..3a244a3 100644 --- a/src/cryo_sbi/wpa_simulator/ctf.py +++ b/src/cryo_sbi/wpa_simulator/ctf.py @@ -3,6 +3,21 @@ def calc_ctf(image_params): + """Calculate the CTF for a given image size and defocus + + Args: + image_params (dict): Dictionary containing the image parameters + N_PIXELS (int): Number of pixels in the image + PIXEL_SIZE (float): Pixel size in Angstrom + DEFOCUS (float or list): Defocus in Angstrom + B_FACTOR (float): B-factor in Angstrom + AMP (float): Amplitude contrast + ELECWAVE (float): Electron wavelength in Angstrom + + Returns: + ctf (torch.Tensor): CTF for the given image size and defocus + """ + # Attention look into def pad_image function to know the image size after padding image_size = ( 2 * (int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1) @@ -42,6 +57,16 @@ def calc_ctf(image_params): def apply_ctf(image, ctf): + """Apply the CTF to an image. + + Args: + image (torch.Tensor): Image to apply the CTF to + ctf (torch.Tensor): CTF to apply to the image + + Returns: + image_ctf (torch.Tensor): Image with the CTF applied + """ + conv_image_ctf = torch.fft.fft2(image) * ctf image_ctf = torch.fft.ifft2(conv_image_ctf).real diff --git a/src/cryo_sbi/wpa_simulator/image_generation.py b/src/cryo_sbi/wpa_simulator/image_generation.py index 6ded7eb..52fc4cf 100644 --- a/src/cryo_sbi/wpa_simulator/image_generation.py +++ b/src/cryo_sbi/wpa_simulator/image_generation.py @@ -3,6 +3,12 @@ def gen_quat(): + """Generate a random quaternion. + + Returns: + quat (np.ndarray): Random quaternion + + """ count = 0 while count < 1: quat = np.random.uniform( @@ -18,6 +24,20 @@ def gen_quat(): def gen_img(coord, image_params): + """Generate an image from a set of coordinates. + + Args: + coord (torch.Tensor): Coordinates of the atoms in the image + image_params (dict): Dictionary containing the image parameters + N_PIXELS (int): Number of pixels along one image size. + PIXEL_SIZE (float): Pixel size in Angstrom + SIGMA (float or list): Standard deviation of the Gaussian function used to model electron density. + ELECWAVE (float): Electron wavelength in Angstrom + + Returns: + image (torch.Tensor): Image generated from the coordinates + """ + n_atoms = coord.shape[1] if isinstance(image_params["SIGMA"], float):