From 5a80e8dad416fdf94d27645e29161414a0fb94a8 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Fri, 26 Jan 2024 09:39:39 +0100 Subject: [PATCH] fixed bug in get_snr --- src/cryo_sbi/wpa_simulator/noise.py | 5 +++-- tests/test_simulator.py | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/cryo_sbi/wpa_simulator/noise.py b/src/cryo_sbi/wpa_simulator/noise.py index 1a8539c..8781d94 100644 --- a/src/cryo_sbi/wpa_simulator/noise.py +++ b/src/cryo_sbi/wpa_simulator/noise.py @@ -34,9 +34,10 @@ def get_snr(images, snr): device=images.device, ) signal_power = torch.std( - images[:, mask], dim=[-2, -1] + images[:, mask], dim=[-1] ) # images are not centered at 0, so std is not the same as power - noise_power = signal_power / torch.sqrt(torch.pow(10, snr)) + assert signal_power.shape[0] == images.shape[0] + noise_power = signal_power / torch.sqrt(torch.pow(snr, torch.tensor(10))) return noise_power diff --git a/tests/test_simulator.py b/tests/test_simulator.py index a6f3195..5e1d0d6 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -54,4 +54,20 @@ def test_gen_rot_matrix_batched(): assert rot_matrix.shape == torch.Size([3, 3, 3]) assert isinstance(rot_matrix, torch.Tensor) - assert torch.allclose(rot_matrix, torch.eye(3).repeat(3, 1, 1)) \ No newline at end of file + assert torch.allclose(rot_matrix, torch.eye(3).repeat(3, 1, 1)) + + +@pytest.mark.parametrize( + ("noise_std", "num_images"), + [(torch.tensor([1.5]), 1), (torch.tensor([1.0, 2.0, 3.0]), 3), (torch.tensor([0.1]), 10)], +) +def test_get_snr(noise_std, num_images): + # Create a test image + images = noise_std.reshape(-1, 1, 1) * torch.randn(num_images, 128, 128) + + # Compute the SNR of the test image + snr = get_snr(images, 1.0) + + assert snr.shape == torch.Size([images.shape[0]]) + assert isinstance(snr, torch.Tensor) + assert torch.allclose(snr, noise_std * torch.ones(images.shape[0]), atol=1e-01) \ No newline at end of file