Skip to content

Commit

Permalink
Added embedding nets for all experimental setups
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Sep 25, 2023
1 parent f4e2481 commit 29353ec
Show file tree
Hide file tree
Showing 6 changed files with 1,921 additions and 56 deletions.
122 changes: 74 additions & 48 deletions notebooks/Untitled.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/analysis_nma_refurbed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.10.0"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
1,792 changes: 1,792 additions & 0 deletions notebooks/test_low_pass_filter.ipynb

Large diffs are not rendered by default.

46 changes: 46 additions & 0 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,52 @@ def forward(self, x):
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET18_FFT_FILTER_224")
class ResNet18_FFT_Encoder_224(nn.Module):
    """
    ResNet-18 embedding network that low-pass filters its input first.

    The stock ResNet-18 is adapted in two places: the first convolution
    takes a single input channel, and the final fully connected layer
    emits ``output_dimension`` features.

    Parameters
    ----------
    output_dimension : int
        Number of features produced by the final linear layer.
    """

    def __init__(self, output_dimension: int):
        super().__init__()

        backbone = models.resnet18()
        # Swap the stem so the network accepts one input channel.
        backbone.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        # Swap the head so the embedding has the requested width.
        backbone.fc = nn.Linear(512, output_dimension)
        self.resnet = backbone

        # NOTE(review): arguments are presumably (image size, frequency
        # cutoff) -- confirm against the LowPassFilter definition.
        self._fft_filter = LowPassFilter(224, 25)

    def forward(self, x):
        """
        Low-pass filter the images, then embed them with the ResNet.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images; assumed to have no channel axis, since one is
            inserted at dimension 1 before the convolution -- TODO confirm.

        Returns
        -------
        torch.Tensor
            Output of the ResNet for the filtered images.
        """
        filtered = self._fft_filter(x)
        # conv1 expects a channel axis at dimension 1.
        return self.resnet(filtered.unsqueeze(1))


@add_embedding("RESNET18_FFT_FILTER_256")
class ResNet18_FFT_Encoder_256(nn.Module):
    """
    ResNet-18 embedding network that low-pass filters its input first.

    The stock ResNet-18 is adapted in two places: the first convolution
    takes a single input channel, and the final fully connected layer
    emits ``output_dimension`` features.

    Parameters
    ----------
    output_dimension : int
        Number of features produced by the final linear layer.
    """

    def __init__(self, output_dimension: int):
        super().__init__()

        backbone = models.resnet18()
        # Swap the stem so the network accepts one input channel.
        backbone.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        # Swap the head so the embedding has the requested width.
        backbone.fc = nn.Linear(512, output_dimension)
        self.resnet = backbone

        # NOTE(review): arguments are presumably (image size, frequency
        # cutoff) -- confirm against the LowPassFilter definition.
        self._fft_filter = LowPassFilter(256, 10)

    def forward(self, x):
        """
        Low-pass filter the images, then embed them with the ResNet.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images; assumed to have no channel axis, since one is
            inserted at dimension 1 before the convolution -- TODO confirm.

        Returns
        -------
        torch.Tensor
            Output of the ResNet for the filtered images.
        """
        filtered = self._fft_filter(x)
        # conv1 expects a channel axis at dimension 1.
        return self.resnet(filtered.unsqueeze(1))


@add_embedding("RESNET34")
Expand Down
4 changes: 2 additions & 2 deletions src/cryo_sbi/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ def _extract_num_particles(self, path):
future_mrc = mrcfile.open_async(path)
mrc = future_mrc.result()
data_shape = mrc.data.shape
img_stack = mrc.is_image_stack()
num_images = data_shape[0] if img_stack else 1
#img_stack = mrc.is_image_stack()
num_images = data_shape[0] if len(data_shape) > 2 else 1
return num_images

def build_index_map(self):
Expand Down
11 changes: 6 additions & 5 deletions src/cryo_sbi/utils/pdb_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,14 @@ def pdb_parser_(input_file: str, mode: str) -> torch.tensor:
return atomic_model


def pdb_parser(input_file_prefix, n_pdbs, output_file, mode):
def pdb_parser(file_formatter, n_pdbs, output_file, mode, start_index=1):
"""
Parses a pdb file and returns an atomic model of the protein. The atomic model is a 5xN array, where N is the number of atoms or residues in the protein. The first three rows are the x, y, z coordinates of the atoms or residues. The fourth row is the atomic number of the atoms or the density of the residues. The fifth row is the variance of the atoms or residues, which is the resolution of the cryo-EM map divided by pi squared.
Parameters
----------
input_file_prefix : str
The path to the pdb file. The pdb files should be named as input_file_prefix0.pdb, input_file_prefix1.pdb, etc.
file_formatter : str
The path to the pdb file. The path must contain the placeholder {} for the pdb index. For example, if the path is "data/pdb/{}.pdb", then the placeholder is {}.
n_pdbs : int
The number of pdb files to parse.
output_file : str
Expand All @@ -177,11 +177,12 @@ def pdb_parser(input_file_prefix, n_pdbs, output_file, mode):
The mode of the atomic model. Either "resid" or "all atom". Resid mode returns a coarse grained atomic model of the protein. All atom mode returns an all atom atomic model of the protein.
"""

atomic_model = pdb_parser_(f"{input_file_prefix}{1}.pdb", mode)
atomic_model = pdb_parser_(file_formatter.format(start_index), mode)
atomic_models = torch.zeros((n_pdbs, *atomic_model.shape))

for i in range(0, n_pdbs):
atomic_models[i] = pdb_parser_(f"{input_file_prefix}{i+1}.pdb", mode)
atomic_models[i] = pdb_parser_(file_formatter.format(start_index+1), mode)


if output_file.endswith("pt"):
torch.save(atomic_models, output_file)
Expand Down

0 comments on commit 29353ec

Please sign in to comment.