From cd046a751710629fa65c2ff0c17a207d60d76b8e Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Fri, 26 Jan 2024 11:13:10 +0100 Subject: [PATCH] fixed new bug in noise test --- src/cryo_sbi/wpa_simulator/noise.py | 4 ++-- tests/test_simulator.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/cryo_sbi/wpa_simulator/noise.py b/src/cryo_sbi/wpa_simulator/noise.py index d9e92bb..717d320 100644 --- a/src/cryo_sbi/wpa_simulator/noise.py +++ b/src/cryo_sbi/wpa_simulator/noise.py @@ -38,7 +38,7 @@ def get_snr(images, snr): ) # images are not centered at 0, so std is not the same as power assert signal_power.shape[0] == images.shape[0] noise_power = signal_power.reshape(-1, 1, 1) / torch.sqrt(torch.pow(torch.tensor(10), snr)) - print(torch.pow(snr, torch.tensor(10))) + return noise_power @@ -60,7 +60,7 @@ def add_noise(image: torch.Tensor, snr, seed=None) -> torch.Tensor: noise_power = get_snr(image, snr) noise = torch.randn_like(image, device=image.device) - print(noise.shape, noise_power.shape, image.shape, snr.shape) + noise = noise * noise_power.reshape(-1, 1, 1) image_noise = image + noise diff --git a/tests/test_simulator.py b/tests/test_simulator.py index f1174da..0d23e96 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -62,7 +62,7 @@ def test_gen_rot_matrix_batched(): @pytest.mark.parametrize( ("noise_std", "num_images"), [ - (torch.tensor([1.5]), 1), + (torch.tensor([1.5, 1]), 2), (torch.tensor([1.0, 2.0, 3.0]), 3), (torch.tensor([0.1]), 10), ], @@ -72,8 +72,8 @@ def test_get_snr(noise_std, num_images): images = noise_std.reshape(-1, 1, 1) * torch.randn(num_images, 128, 128) # Compute the SNR of the test image - snr = get_snr(images, 1.0) + snr = get_snr(images, torch.tensor([0.0])) - assert snr.shape == torch.Size([images.shape[0]]) + assert snr.shape == torch.Size([images.shape[0], 1, 1]), "SNR has wrong shape" assert isinstance(snr, torch.Tensor) - assert torch.allclose(snr, noise_std * torch.ones(images.shape[0]), atol=1e-01) + assert torch.allclose(snr.flatten(), noise_std * torch.ones(images.shape[0]), atol=1e-01), "SNR is not correct"