diff --git a/pyproject.toml b/pyproject.toml index 7890db8..f9a4621 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ dependencies = [ "torch", "numpy", "matplotlib", - "scipy" + "scipy", + "torchvision" ] #dynamic = ["version"] diff --git a/src/cryo_sbi/inference/NPE_train_without_saving.py b/src/cryo_sbi/inference/NPE_train_without_saving.py index b85a8b6..0e10d72 100644 --- a/src/cryo_sbi/inference/NPE_train_without_saving.py +++ b/src/cryo_sbi/inference/NPE_train_without_saving.py @@ -113,5 +113,5 @@ def npe_train_no_saving( loss_file=args.loss_file, train_from_checkpoint=args.train_from_checkpoint, state_dict_file=args.state_dict_file, - n_workers=args.n_workers + n_workers=args.n_workers, ) diff --git a/src/cryo_sbi/inference/NRE_train_without_saving.py b/src/cryo_sbi/inference/NRE_train_without_saving.py index c2176ea..ca84711 100644 --- a/src/cryo_sbi/inference/NRE_train_without_saving.py +++ b/src/cryo_sbi/inference/NRE_train_without_saving.py @@ -22,7 +22,7 @@ def nre_train_no_saving( loss_file, train_from_checkpoint=False, model_state_dict=None, - n_workers=1 + n_workers=1, ): cryo_simulator = CryoEmSimulator(image_config) @@ -110,6 +110,5 @@ def nre_train_no_saving( loss_file=args.loss_file, train_from_checkpoint=args.train_from_checkpoint, state_dict_file=args.state_dict_file, - n_workers=args.n_workers + n_workers=args.n_workers, ) - diff --git a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py index fee6760..ce08ad5 100644 --- a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py +++ b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py @@ -57,7 +57,6 @@ def max_index(self): return len(self.models) - 1 def simulator(self, index, seed=None): - # if seed is not None: # torch.manual_seed(seed) diff --git a/src/cryo_sbi/wpa_simulator/noise.py b/src/cryo_sbi/wpa_simulator/noise.py index 092907a..01c7ecf 100644 --- a/src/cryo_sbi/wpa_simulator/noise.py +++ b/src/cryo_sbi/wpa_simulator/noise.py @@ -11,7 +11,6 @@ def circular_mask(n_pixels, radius): def add_noise(image, image_params, seed=None): - if seed is not None: torch.manual_seed(seed) @@ -27,12 +26,12 @@ def add_noise(image, image_params, seed=None): ) else: - raise ValueError( - "SNR should be a single value or a list of [min_snr, max_snr]" - ) + raise ValueError("SNR should be a single value or a list of [min_snr, max_snr]") noise_power = signal_power / np.sqrt(snr) - image_noise = image + torch.distributions.normal.Normal(0, noise_power).sample(image.shape) + image_noise = image + torch.distributions.normal.Normal(0, noise_power).sample( + image.shape + ) return image_noise diff --git a/src/cryo_sbi/wpa_simulator/padding.py b/src/cryo_sbi/wpa_simulator/padding.py index e7ea6ec..d0b6cc3 100644 --- a/src/cryo_sbi/wpa_simulator/padding.py +++ b/src/cryo_sbi/wpa_simulator/padding.py @@ -4,7 +4,6 @@ def pad_image(image, image_params): - pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1 padder = ConstantPad2d(pad_width, 0.0) padded_image = padder(image) diff --git a/src/cryo_sbi/wpa_simulator/shift.py b/src/cryo_sbi/wpa_simulator/shift.py index 214819e..77ef168 100644 --- a/src/cryo_sbi/wpa_simulator/shift.py +++ b/src/cryo_sbi/wpa_simulator/shift.py @@ -1,13 +1,14 @@ import torch import numpy as np + def apply_random_shift(padded_image, image_params, seed=None): if seed is not None: torch.manual_seed(seed) max_shift = int(np.round(image_params["N_PIXELS"] * 0.1)) - shift_x = int(torch.randint(low=-max_shift, high=max_shift+1, size=(1,))) - shift_y = int(torch.randint(low=-max_shift, high=max_shift+1, size=(1,))) + shift_x = int(torch.randint(low=-max_shift, high=max_shift + 1, size=(1,))) + shift_y = int(torch.randint(low=-max_shift, high=max_shift + 1, size=(1,))) pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1 diff --git a/tests/test_simulator.py b/tests/test_simulator.py index e29a8a5..59345de 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -12,50 +12,54 @@ from cryo_sbi.wpa_simulator.validate_image_config import check_params from cryo_sbi import CryoEmSimulator + def _get_config(): config = json.load(open("tests/image_params_testing.json")) check_params(config) return config + def test_padding(): image_params = _get_config() pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1 image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"])) padded_image = pad_image(image, image_params) - for size in padded_image.shape: - assert size == pad_width*2 + image_params["N_PIXELS"] + for size in padded_image.shape: + assert size == pad_width * 2 + image_params["N_PIXELS"] return + def test_shift_size(): image_params = _get_config() image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"])) padded_image = pad_image(image, image_params) shifted_image = apply_random_shift(padded_image, image_params) - for size in shifted_image.shape: + for size in shifted_image.shape: assert size == image_params["N_PIXELS"] return + def test_shift_bias(): image_params = _get_config() - - x_0 = image_params["N_PIXELS"]//2 - y_0 = image_params["N_PIXELS"]//2 + + x_0 = image_params["N_PIXELS"] // 2 + y_0 = image_params["N_PIXELS"] // 2 image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"])) image[x_0, y_0] = 1 - image[x_0-1, y_0] = 1 - image[x_0, y_0-1] = 1 - image[x_0-1, y_0-1] = 1 - + image[x_0 - 1, y_0] = 1 + image[x_0, y_0 - 1] = 1 + image[x_0 - 1, y_0 - 1] = 1 + padded_image = pad_image(image, image_params) shifted_image = torch.zeros_like(image) for _ in range(10000): shifted_image = shifted_image + apply_random_shift(padded_image, image_params) - + indices_x, indices_y = np.where(shifted_image >= 1) assert np.mean(indices_x) == image_params["N_PIXELS"] / 2 - 0.5 @@ -63,18 +67,20 @@ def test_shift_bias(): return + def test_no_shift(): image_params = _get_config() image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"])) padded_image = pad_image(image, image_params) shifted_image = apply_no_shift(padded_image, image_params) - for size in shifted_image.shape: + for size in shifted_image.shape: assert size == image_params["N_PIXELS"] assert torch.allclose(image, shifted_image) return + def test_normalization(): image_params = _get_config() img_shape = (image_params["N_PIXELS"], image_params["N_PIXELS"]) @@ -85,6 +91,7 @@ def test_normalization(): assert torch.allclose(torch.std(gnormed_image), torch.tensor(1.0), atol=1e-3) return + def test_noise(): image_params = _get_config() N = 10000 @@ -99,14 +106,16 @@ def test_noise(): assert torch.allclose(torch.mean(stds), torch.tensor(1.0), atol=1e-3) return -#def test_ctf(): + +# def test_ctf(): + def test_simulation(): simul = CryoEmSimulator("tests/image_params_testing.json") image_sim = simul.simulator(index=torch.tensor(0), seed=0) image_params = _get_config() - model = np.load(image_params["MODEL_FILE"])[0,0] + model = np.load(image_params["MODEL_FILE"])[0, 0] image = gen_img(model, image_params) image = pad_image(image, image_params) ctf = calc_ctf(image_params)