Skip to content

Commit

Permalink
fixed bug in get_snr
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jan 26, 2024
1 parent a19e388 commit 5a80e8d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/cryo_sbi/wpa_simulator/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ def get_snr(images, snr):
device=images.device,
)
signal_power = torch.std(
images[:, mask], dim=[-2, -1]
images[:, mask], dim=[-1]
) # images are not centered at 0, so std is not the same as power
noise_power = signal_power / torch.sqrt(torch.pow(10, snr))
assert signal_power.shape[0] == images.shape[0]
noise_power = signal_power / torch.sqrt(torch.pow(snr, torch.tensor(10)))

return noise_power

Expand Down
18 changes: 17 additions & 1 deletion tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,20 @@ def test_gen_rot_matrix_batched():

assert rot_matrix.shape == torch.Size([3, 3, 3])
assert isinstance(rot_matrix, torch.Tensor)
assert torch.allclose(rot_matrix, torch.eye(3).repeat(3, 1, 1))
assert torch.allclose(rot_matrix, torch.eye(3).repeat(3, 1, 1))


@pytest.mark.parametrize(
("noise_std", "num_images"),
[(torch.tensor([1.5]), 1), (torch.tensor([1.0, 2.0, 3.0]), 3), (torch.tensor([0.1]), 10)],
)
def test_get_snr(noise_std, num_images):
# Create a test image
images = noise_std.reshape(-1, 1, 1) * torch.randn(num_images, 128, 128)

# Compute the SNR of the test image
snr = get_snr(images, 1.0)

assert snr.shape == torch.Size([images.shape[0]])
assert isinstance(snr, torch.Tensor)
assert torch.allclose(snr, noise_std * torch.ones(images.shape[0]), atol=1e-01)

0 comments on commit 5a80e8d

Please sign in to comment.