From 07db6f324cd4ca5c8b295e854891e4924ab31246 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Fri, 26 Jan 2024 10:58:36 +0100 Subject: [PATCH] fixed bug in noise calculation --- src/cryo_sbi/wpa_simulator/noise.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/cryo_sbi/wpa_simulator/noise.py b/src/cryo_sbi/wpa_simulator/noise.py index 8781d94..d9e92bb 100644 --- a/src/cryo_sbi/wpa_simulator/noise.py +++ b/src/cryo_sbi/wpa_simulator/noise.py @@ -37,8 +37,8 @@ def get_snr(images, snr): images[:, mask], dim=[-1] ) # images are not centered at 0, so std is not the same as power assert signal_power.shape[0] == images.shape[0] - noise_power = signal_power / torch.sqrt(torch.pow(snr, torch.tensor(10))) - + noise_power = signal_power.reshape(-1, 1, 1) / torch.sqrt(torch.pow(torch.tensor(10), snr)) + print(torch.pow(snr, torch.tensor(10))) return noise_power @@ -56,10 +56,11 @@ def add_noise(image: torch.Tensor, snr, seed=None) -> torch.Tensor: """ if seed is not None: - torch.manual_seed(seed) # + torch.manual_seed(seed) noise_power = get_snr(image, snr) noise = torch.randn_like(image, device=image.device) + print(noise.shape, noise_power.shape, image.shape, snr.shape) noise = noise * noise_power.reshape(-1, 1, 1) image_noise = image + noise