diff --git a/data/protein_models/6wxb_bending_models.npy b/data/protein_models/6wxb_bending_models.npy new file mode 100644 index 0000000..0636038 Binary files /dev/null and b/data/protein_models/6wxb_bending_models.npy differ diff --git a/src/cryo_sbi/inference/NPE_train_from_disk.py b/src/cryo_sbi/inference/NPE_train_from_disk.py index 22344a3..3c8d25f 100644 --- a/src/cryo_sbi/inference/NPE_train_from_disk.py +++ b/src/cryo_sbi/inference/NPE_train_from_disk.py @@ -41,7 +41,7 @@ def npe_train_from_disk( shuffle=True, batch_size=train_config["BATCH_SIZE"], chunk_size=train_config["BATCH_SIZE"], - chunk_step=2 ** 2, + chunk_step=2**2, ) validset = H5Dataset( @@ -49,7 +49,7 @@ def npe_train_from_disk( shuffle=True, batch_size=train_config["BATCH_SIZE"], chunk_size=train_config["BATCH_SIZE"], - chunk_step=2 ** 2, + chunk_step=2**2, ) train_loader = torch.utils.data.DataLoader( diff --git a/src/cryo_sbi/inference/models/embedding_nets.py b/src/cryo_sbi/inference/models/embedding_nets.py index 8906670..021bbcc 100644 --- a/src/cryo_sbi/inference/models/embedding_nets.py +++ b/src/cryo_sbi/inference/models/embedding_nets.py @@ -2,7 +2,7 @@ import torch.nn as nn import torchvision.models as models -from cryo_sbi.utils.image_utils import LowPassFilter +from cryo_sbi.utils.image_utils import LowPassFilter, Mask EMBEDDING_NETS = {} @@ -338,5 +338,31 @@ def forward(self, x): return x +@add_embedding("RESNET18_FFT_FILTER_MASK") +class ResNet18_FFT_Encoder(nn.Module): + def __init__(self, output_dimension): + super(ResNet18_FFT_Encoder, self).__init__() + self.resnet = models.resnet18() + self.resnet.conv1 = nn.Conv2d( + 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False + ) + self.resnet.fc = nn.Linear( + in_features=512, out_features=output_dimension, bias=True + ) + + self._fft_filter = LowPassFilter(128, 80) + self._masking = Mask(128, 40) + + def forward(self, x): + # Masking images + x = self._masking(x) + # Low pass filter images + x = self._fft_filter(x) + # Proceed as normal + x = x.unsqueeze(1) + x = self.resnet(x) + return x + + if __name__ == "__main__": pass diff --git a/src/cryo_sbi/utils/image_utils.py b/src/cryo_sbi/utils/image_utils.py index 19ad051..e548512 100644 --- a/src/cryo_sbi/utils/image_utils.py +++ b/src/cryo_sbi/utils/image_utils.py @@ -8,9 +8,9 @@ def circular_mask(n_pixels, radius, inside=True): r_2d = grid[None, :] ** 2 + grid[:, None] ** 2 if inside is True: - mask = r_2d < radius ** 2 + mask = r_2d < radius**2 else: - mask = r_2d > radius ** 2 + mask = r_2d > radius**2 return mask diff --git a/src/cryo_sbi/wpa_simulator/ctf.py b/src/cryo_sbi/wpa_simulator/ctf.py index b737f42..d7fa173 100644 --- a/src/cryo_sbi/wpa_simulator/ctf.py +++ b/src/cryo_sbi/wpa_simulator/ctf.py @@ -29,7 +29,7 @@ def calc_ctf(image_params): x, y = torch.meshgrid(freq_pix_1d, freq_pix_1d, indexing="ij") - freq2_2d = x ** 2 + y ** 2 + freq2_2d = x**2 + y**2 imag = torch.zeros_like(freq2_2d) * 1j env = torch.exp(-image_params["B_FACTOR"] * freq2_2d * 0.5) diff --git a/src/cryo_sbi/wpa_simulator/image_generation.py b/src/cryo_sbi/wpa_simulator/image_generation.py index b4adaae..6ded7eb 100644 --- a/src/cryo_sbi/wpa_simulator/image_generation.py +++ b/src/cryo_sbi/wpa_simulator/image_generation.py @@ -8,7 +8,7 @@ def gen_quat(): quat = np.random.uniform( -1, 1, 4 ) # note this is a half-open interval, so 1 is not included but -1 is - norm = np.sqrt(np.sum(quat ** 2)) + norm = np.sqrt(np.sum(quat**2)) if 0.2 <= norm <= 1.0: quat /= norm @@ -29,9 +29,11 @@ def gen_img(coord, image_params): ) else: - raise ValueError("SIGMA should be a single value or a list of [min_sigma, max_sigma]") + raise ValueError( + "SIGMA should be a single value or a list of [min_sigma, max_sigma]" + ) - norm = 1 / (2 * torch.pi * atom_sigma ** 2 * n_atoms) + norm = 1 / (2 * torch.pi * atom_sigma**2 * n_atoms) grid_min = -image_params["PIXEL_SIZE"] * (image_params["N_PIXELS"] - 1) * 0.5 grid_max = ( diff --git a/src/cryo_sbi/wpa_simulator/noise.py b/src/cryo_sbi/wpa_simulator/noise.py index 0511310..9d56a15 100644 --- a/src/cryo_sbi/wpa_simulator/noise.py +++ b/src/cryo_sbi/wpa_simulator/noise.py @@ -5,7 +5,7 @@ def circular_mask(n_pixels, radius): grid = torch.linspace(-0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels) r_2d = grid[None, :] ** 2 + grid[:, None] ** 2 - mask = r_2d < radius ** 2 + mask = r_2d < radius**2 return mask