diff --git a/experiments/6wxb/image_params_mixed_training.json b/experiments/6wxb/image_params_mixed_training.json index fb1faa8..964efe7 100644 --- a/experiments/6wxb/image_params_mixed_training.json +++ b/experiments/6wxb/image_params_mixed_training.json @@ -1,9 +1,9 @@ { "N_PIXELS": 128, "PIXEL_SIZE": 2.06, - "SIGMA": [4.0, 15.0], + "SIGMA": [0.5, 5.0], "MODEL_FILE": "../data/protein_models/6wxb_mixed_models.npy", - "ROTATIONS": false, + "ROTATIONS": true, "SHIFT": true, "CTF": true, "NOISE": true, diff --git a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py index d0eadc6..3855126 100644 --- a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py +++ b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py @@ -10,6 +10,7 @@ from cryo_sbi.wpa_simulator.padding import pad_image from cryo_sbi.wpa_simulator.shift import apply_no_shift, apply_random_shift from cryo_sbi.wpa_simulator.validate_image_config import check_params +from cryo_sbi.wpa_simulator.implicit_water import add_noise_field class CryoEmSimulator: diff --git a/src/cryo_sbi/wpa_simulator/implicit_water.py b/src/cryo_sbi/wpa_simulator/implicit_water.py new file mode 100644 index 0000000..131b716 --- /dev/null +++ b/src/cryo_sbi/wpa_simulator/implicit_water.py @@ -0,0 +1,47 @@ +import torch + + +def gen_noise_field(num_pixels, num_sin_func=10, max_intensity=1e-3): + """Generate a noise field with a given number of sinusoidal functions. + + Args: + num_pixels (int): Number of pixels in the noise field. + num_sin_func (int, optional): Number of sinusoidal functions. Defaults to 10. + max_intensity (float, optional): Maximum intensity of the noise field. Defaults to 1e-3. + + Returns: + torch.Tensor: Noise field. + """ + + x = torch.linspace(-100, 100, num_pixels) + y = torch.linspace(-100, 100, num_pixels) + xx, yy = torch.meshgrid(x, y) + + b = 0.6 * (torch.rand((num_sin_func, 2)) - 0.5) + c = 2 * torch.pi * (torch.rand(num_sin_func, 2) - 0.5) + + noise_field = torch.zeros_like(xx, dtype=torch.double) + for i in range(num_sin_func): + noise_field += torch.sin(b[i, 0] * xx + c[i, 0]) * torch.sin( + b[i, 1] * yy + c[i, 1] + ) + noise_field = max_intensity * (noise_field / noise_field.max()) + return noise_field + + +def add_noise_field(image, min_intensity): + """Add a noise field to an image. + + Args: + image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels). + min_intensity (float): Minimum intensity of the image. + + Returns: + torch.Tensor: Image with noise field. + """ + + noise_field = gen_noise_field(image.shape[0], max_intensity=1e-12) + idx_replace = image < min_intensity + image[idx_replace] = noise_field[idx_replace] + + return image