Skip to content

Commit

Permalink
fixed iss 52 in simulator. Updated test for simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Feb 6, 2024
1 parent dc2cea1 commit ebfaf0e
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No
assert (
indices.ndim == 2
), "Indices are not a 2D tensor, converting to 2D tensor. With shape (batch_size, 1)."
indices = torch.tensor(indices, dtype=torch.float32)
parameters[0] = indices

images = []
if batch_size is None:
Expand Down
4 changes: 2 additions & 2 deletions tests/config_files/image_params_testing.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"N_PIXELS": 128,
"N_PIXELS": 64,
"PIXEL_SIZE": 2.06,
"SIGMA": [0.5, 5.0],
"MODEL_FILE": "../models/hemagglutinin_models.pt",
"MODEL_FILE": "tests/models/hsp90_models.pt",
"SHIFT": 20.0,
"DEFOCUS": [1.5, 3.5],
"SNR": [0.05, 0.05],
Expand Down
Binary file modified tests/models/hsp90_models.pt
Binary file not shown.
19 changes: 18 additions & 1 deletion tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import json

from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator
from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator, CryoEmSimulator
from cryo_sbi.wpa_simulator.ctf import apply_ctf
from cryo_sbi.wpa_simulator.image_generation import (
project_density,
Expand Down Expand Up @@ -79,3 +79,20 @@ def test_get_snr(noise_std, num_images):
assert torch.allclose(
snr.flatten(), noise_std * torch.ones(images.shape[0]), atol=1e-01
), "SNR is not correct"

@pytest.mark.parametrize(("num_images"), [2, 3, 10])
def test_simulator_default_settings(num_images):
sim = CryoEmSimulator("tests/config_files/image_params_testing.json")
images = sim.simulate(num_images)
assert images.shape == torch.Size([num_images, 64, 64])


def test_simulator_custom_indices():
sim = CryoEmSimulator("tests/config_files/image_params_testing.json")
test_indices = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.float32).reshape(-1, 1)
images, parameters = sim.simulate(6, indices=test_indices, return_parameters=True)

assert (parameters[0] == test_indices).all().item()
assert images.shape == torch.Size([6, 64, 64])


0 comments on commit ebfaf0e

Please sign in to comment.