Skip to content

Commit

Permalink
started to add tutorial, added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jan 29, 2024
1 parent 343c86c commit 7caf5bf
Show file tree
Hide file tree
Showing 15 changed files with 489 additions and 1,288 deletions.
33 changes: 33 additions & 0 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,39 @@ def forward(self, x):
x = self.avgpool(x).flatten(start_dim=1)
x = self.feedforward(x)
return x


@add_embedding("ConvEncoder_Tutorial")
class ConvEncoder(nn.Module):
def __init__(self, output_dimension: int):
super(ConvEncoder, self).__init__()
ndf = 16 # fixed for the tutorial
self.main = nn.Sequential(
# input is 1 x 64 x 64
nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
# nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
# nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
# nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, output_dimension, 4, 1, 0, bias=False),
# state size. out_dims x 1 x 1
)

def forward(self, x):
x = x.view(-1, 1, 64, 64)
x = self.main(x)
return x.view(x.size(0), -1) # flatten



if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def test_embedding(embedding_name, num_images, out_dim):
if "FFT_FILTER_" in embedding_name:
size = embedding_name.split("FFT_FILTER_")[1]
test_images = torch.randn(num_images, int(size), int(size))
elif "Tutorial" in embedding_name:
test_images = torch.randn(num_images, 64, 64)
else:
test_images = torch.randn(num_images, 128, 128)

Expand Down
10 changes: 10 additions & 0 deletions tests/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def test_low_pass_filter():
assert filtered_image.shape == (image_size, image_size)


def test_gaussian_low_pass_filter():
image_size = 100
frequency_cutoff = 30
low_pass_filter = iu.GaussianLowPassFilter(image_size, frequency_cutoff)
image = torch.ones((image_size, image_size))

filtered_image = low_pass_filter(image)
assert filtered_image.shape == (image_size, image_size)


def test_normalize_individual():
normalize_individual = iu.NormalizeIndividual()
image = torch.randn((3, 100, 100))
Expand Down
3 changes: 2 additions & 1 deletion tutorials/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.h5
*.tp
*.tp
*.pt
149 changes: 0 additions & 149 deletions tutorials/first_tutorial.ipynb

This file was deleted.

562 changes: 0 additions & 562 deletions tutorials/hsp90_analysis.ipynb

This file was deleted.

Binary file removed tutorials/hsp90_models.npy
Binary file not shown.
Binary file removed tutorials/image_anim.npy
Binary file not shown.
15 changes: 0 additions & 15 deletions tutorials/image_params_snr01_128.json

This file was deleted.

Binary file removed tutorials/quaternion_list.npy
Binary file not shown.
12 changes: 0 additions & 12 deletions tutorials/resnet18_encoder.json

This file was deleted.

11 changes: 11 additions & 0 deletions tutorials/simulation_parameters.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"N_PIXELS": 64,
"PIXEL_SIZE": 2.0,
"SIGMA": [2.0, 2.0],
"MODEL_FILE": "models.pt",
"SHIFT": 0.0,
"DEFOCUS": [2.0, 2.0],
"SNR": [0.01, 0.5],
"AMP": 0.1,
"B_FACTOR": [1.0, 1.0]
}
14 changes: 14 additions & 0 deletions tutorials/train_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"EMBEDDING": "ConvEncoder_Tutorial",
"OUT_DIM": 128,
"NUM_TRANSFORM": 5,
"NUM_HIDDEN_FLOW": 5,
"HIDDEN_DIM_FLOW": 128,
"MODEL": "NSF",
"LEARNING_RATE": 0.0003,
"CLIP_GRADIENT": 5.0,
"THETA_SHIFT": 50,
"THETA_SCALE": 50,
"BATCH_SIZE": 32
}

Loading

0 comments on commit 7caf5bf

Please sign in to comment.