Skip to content

Commit

Permalink
new embedding with image masking
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Apr 3, 2023
1 parent fc10aa8 commit 9167281
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 10 deletions.
Binary file added data/protein_models/6wxb_bending_models.npy
Binary file not shown.
4 changes: 2 additions & 2 deletions src/cryo_sbi/inference/NPE_train_from_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def npe_train_from_disk(
shuffle=True,
batch_size=train_config["BATCH_SIZE"],
chunk_size=train_config["BATCH_SIZE"],
chunk_step=2 ** 2,
chunk_step=2**2,
)

validset = H5Dataset(
val_data_dir,
shuffle=True,
batch_size=train_config["BATCH_SIZE"],
chunk_size=train_config["BATCH_SIZE"],
chunk_step=2 ** 2,
chunk_step=2**2,
)

train_loader = torch.utils.data.DataLoader(
Expand Down
28 changes: 27 additions & 1 deletion src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torchvision.models as models

from cryo_sbi.utils.image_utils import LowPassFilter
from cryo_sbi.utils.image_utils import LowPassFilter, Mask


EMBEDDING_NETS = {}
Expand Down Expand Up @@ -338,5 +338,31 @@ def forward(self, x):
return x


@add_embedding("RESNET18_FFT_FILTER_MASK")
class ResNet18_FFT_Encoder(nn.Module):
def __init__(self, output_dimension):
super(ResNet18_FFT_Encoder, self).__init__()
self.resnet = models.resnet18()
self.resnet.conv1 = nn.Conv2d(
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)

self._fft_filter = LowPassFilter(128, 80)
self._masking = Mask(128, 40)

def forward(self, x):
# Masking images
x = self._masking(x)
# Low pass filter images
x = self._fft_filter(x)
# Proceed as normal
x = x.unsqueeze(1)
x = self.resnet(x)
return x


if __name__ == "__main__":
pass
4 changes: 2 additions & 2 deletions src/cryo_sbi/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ def circular_mask(n_pixels, radius, inside=True):
r_2d = grid[None, :] ** 2 + grid[:, None] ** 2

if inside is True:
mask = r_2d < radius ** 2
mask = r_2d < radius**2
else:
mask = r_2d > radius ** 2
mask = r_2d > radius**2

return mask

Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/wpa_simulator/ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def calc_ctf(image_params):

x, y = torch.meshgrid(freq_pix_1d, freq_pix_1d, indexing="ij")

freq2_2d = x ** 2 + y ** 2
freq2_2d = x**2 + y**2
imag = torch.zeros_like(freq2_2d) * 1j

env = torch.exp(-image_params["B_FACTOR"] * freq2_2d * 0.5)
Expand Down
8 changes: 5 additions & 3 deletions src/cryo_sbi/wpa_simulator/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def gen_quat():
quat = np.random.uniform(
-1, 1, 4
) # note this is a half-open interval, so 1 is not included but -1 is
norm = np.sqrt(np.sum(quat ** 2))
norm = np.sqrt(np.sum(quat**2))

if 0.2 <= norm <= 1.0:
quat /= norm
Expand All @@ -29,9 +29,11 @@ def gen_img(coord, image_params):
)

else:
raise ValueError("SIGMA should be a single value or a list of [min_sigma, max_sigma]")
raise ValueError(
"SIGMA should be a single value or a list of [min_sigma, max_sigma]"
)

norm = 1 / (2 * torch.pi * atom_sigma ** 2 * n_atoms)
norm = 1 / (2 * torch.pi * atom_sigma**2 * n_atoms)

grid_min = -image_params["PIXEL_SIZE"] * (image_params["N_PIXELS"] - 1) * 0.5
grid_max = (
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/wpa_simulator/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def circular_mask(n_pixels, radius):
grid = torch.linspace(-0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels)
r_2d = grid[None, :] ** 2 + grid[:, None] ** 2
mask = r_2d < radius ** 2
mask = r_2d < radius**2

return mask

Expand Down

0 comments on commit 9167281

Please sign in to comment.