From e1568247d444f0a43117a9949f7654c68cf0f8c2 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Thu, 7 Mar 2024 15:58:45 +0100 Subject: [PATCH] formatted code --- src/cryo_sbi/inference/priors.py | 4 +- src/cryo_sbi/utils/estimator_utils.py | 2 +- src/cryo_sbi/utils/image_utils.py | 6 +- src/cryo_sbi/utils/visualize_models.py | 28 ++++---- src/cryo_sbi/wpa_simulator/ctf.py | 4 +- .../wpa_simulator/image_generation.py | 4 +- src/cryo_sbi/wpa_simulator/noise.py | 2 +- tests/test_image_utils.py | 6 +- tests/test_visualize_models.py | 14 ++-- tutorials/tutorial.ipynb | 66 +++++++++++++------ 10 files changed, 86 insertions(+), 50 deletions(-) diff --git a/src/cryo_sbi/inference/priors.py b/src/cryo_sbi/inference/priors.py index 6342b1f..f8fbb29 100644 --- a/src/cryo_sbi/inference/priors.py +++ b/src/cryo_sbi/inference/priors.py @@ -15,7 +15,7 @@ def gen_quat() -> torch.Tensor: count = 0 while count < 1: quat = 2 * torch.rand(size=(4,)) - 1 - norm = torch.sqrt(torch.sum(quat ** 2)) + norm = torch.sqrt(torch.sum(quat**2)) if 0.2 <= norm <= 1.0: quat /= norm count += 1 @@ -207,7 +207,7 @@ class PriorLoader(DataLoader): def __init__( self, prior: Distribution, - batch_size: int = 2 ** 8, # 256 + batch_size: int = 2**8, # 256 **kwargs, ): super().__init__( diff --git a/src/cryo_sbi/utils/estimator_utils.py b/src/cryo_sbi/utils/estimator_utils.py index acefe40..44523be 100644 --- a/src/cryo_sbi/utils/estimator_utils.py +++ b/src/cryo_sbi/utils/estimator_utils.py @@ -24,7 +24,7 @@ def evaluate_log_prob( Returns: torch.Tensor: The log probabilities of the images under the estimator. """ - + # batching images if necessary if images.shape[0] > batch_size and batch_size > 0: images = torch.split(images, split_size_or_sections=batch_size, dim=0) diff --git a/src/cryo_sbi/utils/image_utils.py b/src/cryo_sbi/utils/image_utils.py index be0efd5..4bab630 100644 --- a/src/cryo_sbi/utils/image_utils.py +++ b/src/cryo_sbi/utils/image_utils.py @@ -25,9 +25,9 @@ def circular_mask(n_pixels: int, radius: int, inside: bool = True) -> torch.Tens r_2d = grid[None, :] ** 2 + grid[:, None] ** 2 if inside is True: - mask = r_2d < radius ** 2 + mask = r_2d < radius**2 else: - mask = r_2d > radius ** 2 + mask = r_2d > radius**2 return mask @@ -183,7 +183,7 @@ def __init__(self, image_size: int, sigma: int): -0.5 * (image_size - 1), 0.5 * (image_size - 1), image_size ) self._r_2d = self._grid[None, :] ** 2 + self._grid[:, None] ** 2 - self._mask = torch.exp(-self._r_2d / (2 * sigma ** 2)) + self._mask = torch.exp(-self._r_2d / (2 * sigma**2)) def __call__(self, image: torch.Tensor) -> torch.Tensor: """ diff --git a/src/cryo_sbi/utils/visualize_models.py b/src/cryo_sbi/utils/visualize_models.py index c4144c6..a90d7d0 100644 --- a/src/cryo_sbi/utils/visualize_models.py +++ b/src/cryo_sbi/utils/visualize_models.py @@ -3,21 +3,28 @@ import torch -def _scatter_plot_models(model: torch.Tensor, view_angles : tuple = (30, 45), **plot_kwargs: dict) -> None: +def _scatter_plot_models( + model: torch.Tensor, view_angles: tuple = (30, 45), **plot_kwargs: dict +) -> None: fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') + ax = fig.add_subplot(111, projection="3d") ax.view_init(*view_angles) ax.scatter(*model, **plot_kwargs) - ax.set_xlabel('X') - ax.set_ylabel('Y') - ax.set_zlabel('Z') + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") -def _sphere_plot_models(model: torch.Tensor, radius: float = 4, view_angles : tuple = (30, 45), **plot_kwargs: dict,) -> None: +def _sphere_plot_models( + model: torch.Tensor, + radius: float = 4, + view_angles: tuple = (30, 45), + **plot_kwargs: dict, +) -> None: fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') + ax = fig.add_subplot(111, projection="3d") ax.view_init(30, 45) spheres = [] @@ -35,9 +42,9 @@ def _sphere_plot_models(model: torch.Tensor, radius: float = 4, view_angles : tu ax.plot_surface(x, y, z, **plot_kwargs) - ax.set_xlabel('X') - ax.set_ylabel('Y') - ax.set_zlabel('Z') + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") def plot_model(model: torch.Tensor, method: str = "scatter", **kwargs) -> None: @@ -70,4 +77,3 @@ def plot_model(model: torch.Tensor, method: str = "scatter", **kwargs) -> None: else: raise ValueError(f"Unknown method {method}. Use 'scatter' or 'sphere'.") - diff --git a/src/cryo_sbi/wpa_simulator/ctf.py b/src/cryo_sbi/wpa_simulator/ctf.py index 9717e67..35a4a81 100644 --- a/src/cryo_sbi/wpa_simulator/ctf.py +++ b/src/cryo_sbi/wpa_simulator/ctf.py @@ -21,7 +21,7 @@ def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch. freq_pix_1d = torch.fft.fftfreq(num_pixels, d=pixel_size, device=image.device) x, y = torch.meshgrid(freq_pix_1d, freq_pix_1d, indexing="ij") - freq2_2d = x ** 2 + y ** 2 + freq2_2d = x**2 + y**2 freq2_2d = freq2_2d.expand(num_batch, -1, -1) imag = torch.zeros_like(freq2_2d, device=image.device) * 1j @@ -30,7 +30,7 @@ def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch. ctf = ( -amp * torch.cos(phase * freq2_2d * 0.5) - - torch.sqrt(1 - amp ** 2) * torch.sin(phase * freq2_2d * 0.5) + - torch.sqrt(1 - amp**2) * torch.sin(phase * freq2_2d * 0.5) + imag ) ctf = ctf * env / amp diff --git a/src/cryo_sbi/wpa_simulator/image_generation.py b/src/cryo_sbi/wpa_simulator/image_generation.py index c19a18b..3e9bc00 100644 --- a/src/cryo_sbi/wpa_simulator/image_generation.py +++ b/src/cryo_sbi/wpa_simulator/image_generation.py @@ -13,7 +13,7 @@ def gen_quat() -> torch.Tensor: count = 0 while count < 1: quat = 2 * torch.rand(size=(4,)) - 1 - norm = torch.sqrt(torch.sum(quat ** 2)) + norm = torch.sqrt(torch.sum(quat**2)) if 0.2 <= norm <= 1.0: quat /= norm count += 1 @@ -72,7 +72,7 @@ def project_density( """ num_batch, _, num_atoms = coords.shape - norm = 1 / (2 * torch.pi * sigma ** 2 * num_atoms) + norm = 1 / (2 * torch.pi * sigma**2 * num_atoms) grid_min = -pixel_size * num_pixels * 0.5 grid_max = pixel_size * num_pixels * 0.5 diff --git a/src/cryo_sbi/wpa_simulator/noise.py b/src/cryo_sbi/wpa_simulator/noise.py index 013c066..a763bb2 100644 --- a/src/cryo_sbi/wpa_simulator/noise.py +++ b/src/cryo_sbi/wpa_simulator/noise.py @@ -19,7 +19,7 @@ def circular_mask(n_pixels: int, radius: int, device: str = "cpu") -> torch.Tens -0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels, device=device ) r_2d = grid[None, :] ** 2 + grid[:, None] ** 2 - mask = r_2d < radius ** 2 + mask = r_2d < radius**2 return mask diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py index dadef69..ff21436 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -11,9 +11,9 @@ def test_circular_mask(): assert inside_mask.shape == (n_pixels, n_pixels) assert outside_mask.shape == (n_pixels, n_pixels) - assert inside_mask.sum().item() == pytest.approx(radius ** 2 * 3.14159, abs=10) + assert inside_mask.sum().item() == pytest.approx(radius**2 * 3.14159, abs=10) assert outside_mask.sum().item() == pytest.approx( - n_pixels ** 2 - radius ** 2 * 3.14159, abs=10 + n_pixels**2 - radius**2 * 3.14159, abs=10 ) @@ -27,7 +27,7 @@ def test_mask_class(): masked_image = mask(image) assert masked_image.shape == (image_size, image_size) assert masked_image[inside].sum().item() == pytest.approx( - image_size ** 2 - radius ** 2 * 3.14159, abs=10 + image_size**2 - radius**2 * 3.14159, abs=10 ) diff --git a/tests/test_visualize_models.py b/tests/test_visualize_models.py index 108d433..2cc59ab 100644 --- a/tests/test_visualize_models.py +++ b/tests/test_visualize_models.py @@ -5,16 +5,22 @@ def test_plot_model_scatter(): model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) - plot_model(model, method="scatter") # No assertion, just checking if it runs without errors + plot_model( + model, method="scatter" + ) # No assertion, just checking if it runs without errors def test_plot_model_sphere(): model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) - plot_model(model, method="sphere") # No assertion, just checking if it runs without errors + plot_model( + model, method="sphere" + ) # No assertion, just checking if it runs without errors def test_plot_model_invalid_model(): - model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # Invalid shape, should have 3 rows + model = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + ) # Invalid shape, should have 3 rows with pytest.raises(AssertionError): plot_model(model, method="scatter") @@ -22,4 +28,4 @@ def test_plot_model_invalid_model(): def test_plot_model_invalid_method(): model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) with pytest.raises(ValueError): - plot_model(model, method="invalid_method") \ No newline at end of file + plot_model(model, method="invalid_method") diff --git a/tutorials/tutorial.ipynb b/tutorials/tutorial.ipynb index d0ffc2c..59c71ef 100644 --- a/tutorials/tutorial.ipynb +++ b/tutorials/tutorial.ipynb @@ -47,9 +47,9 @@ "models = []\n", "for side_length in side_lengths:\n", " model = [\n", - " [side_length, -side_length, side_length, -side_length],\n", - " [side_length, side_length, -side_length, -side_length], \n", - " [0.0, 0.0, 0.0, 0.0]\n", + " [side_length, -side_length, side_length, -side_length],\n", + " [side_length, side_length, -side_length, -side_length],\n", + " [0.0, 0.0, 0.0, 0.0],\n", " ]\n", " models.append(model)\n", "models = torch.tensor(models)" @@ -91,14 +91,20 @@ "source": [ "fig = plt.figure()\n", "ax = fig.add_subplot()\n", - "for i, c in zip([0, 25, 50, 75, 99], ['red', 'orange', 'green', 'blue', 'purple']):\n", - " ax.scatter(models[i, 0, 0], models[i, 1, 0], s=60, color=c, label=f\"Model with angle : {i:.2f}\")\n", + "for i, c in zip([0, 25, 50, 75, 99], [\"red\", \"orange\", \"green\", \"blue\", \"purple\"]):\n", + " ax.scatter(\n", + " models[i, 0, 0],\n", + " models[i, 1, 0],\n", + " s=60,\n", + " color=c,\n", + " label=f\"Model with angle : {i:.2f}\",\n", + " )\n", " ax.scatter(models[i, 0, 1], models[i, 1, 1], s=60, color=c)\n", " ax.scatter(models[i, 0, 2], models[i, 1, 2], s=60, color=c)\n", " ax.scatter(models[i, 0, 3], models[i, 1, 3], s=60, color=c)\n", "\n", - "ax.set_xlabel('X')\n", - "ax.set_ylabel('Y')\n", + "ax.set_xlabel(\"X\")\n", + "ax.set_ylabel(\"Y\")\n", "plt.legend()" ] }, @@ -112,7 +118,7 @@ "for i in range(100):\n", " models[i] = models[i] - models[i].mean(dim=1, keepdim=True)\n", "\n", - "torch.save(models, 'models.pt')" + "torch.save(models, \"models.pt\")" ] }, { @@ -155,7 +161,9 @@ "metadata": {}, "outputs": [], "source": [ - "simulator = CryoEmSimulator(\"simulation_parameters.json\") # creating simulator with simulation parameters" + "simulator = CryoEmSimulator(\n", + " \"simulation_parameters.json\"\n", + ") # creating simulator with simulation parameters" ] }, { @@ -164,7 +172,9 @@ "metadata": {}, "outputs": [], "source": [ - "images, parameters = simulator.simulate(num_sim=5000, return_parameters=True) # simulating images and save parameters" + "images, parameters = simulator.simulate(\n", + " num_sim=5000, return_parameters=True\n", + ") # simulating images and save parameters" ] }, { @@ -173,8 +183,8 @@ "metadata": {}, "outputs": [], "source": [ - "dist = parameters[0] # extracting distance from parameters\n", - "snr = parameters[-1] # extracting snr from parameters" + "dist = parameters[0] # extracting distance from parameters\n", + "snr = parameters[-1] # extracting snr from parameters" ] }, { @@ -204,7 +214,9 @@ "fig, axes = plt.subplots(4, 4, figsize=(6, 6))\n", "for idx, ax in enumerate(axes.flatten()):\n", " ax.imshow(images[idx], vmin=-3, vmax=3, cmap=\"gray\")\n", - " ax.set_title(f\"Side: {side_lengths[dist[idx].round().long()].item():.2f}\", fontsize=10)\n", + " ax.set_title(\n", + " f\"Side: {side_lengths[dist[idx].round().long()].item():.2f}\", fontsize=10\n", + " )\n", " ax.axis(\"off\")" ] }, @@ -270,12 +282,12 @@ " \"simulation_parameters.json\",\n", " \"training_parameters.json\",\n", " 150,\n", - " \"tutorial_estimator.pt\", # name of the estimator file\n", - " \"tutorial.loss\", # name of the loss file\n", - " n_workers=4, # number of workers for data loading\n", - " device=\"cuda\", # device to use for training and simulation\n", - " saving_frequency=100, # frequency of saving the model\n", - " simulation_batch_size=160, # batch size for simulation\n", + " \"tutorial_estimator.pt\", # name of the estimator file\n", + " \"tutorial.loss\", # name of the loss file\n", + " n_workers=4, # number of workers for data loading\n", + " device=\"cuda\", # device to use for training and simulation\n", + " saving_frequency=100, # frequency of saving the model\n", + " simulation_batch_size=160, # batch size for simulation\n", ")" ] }, @@ -447,7 +459,13 @@ } ], "source": [ - "plt.scatter(latent_vecs_transformed[:, 0], latent_vecs_transformed[:, 1], c=dist, cmap=\"viridis\", s=10)\n", + "plt.scatter(\n", + " latent_vecs_transformed[:, 0],\n", + " latent_vecs_transformed[:, 1],\n", + " c=dist,\n", + " cmap=\"viridis\",\n", + " s=10,\n", + ")\n", "plt.colorbar(label=\"Side length\")\n", "plt.xlabel(\"UMAP 1\")\n", "plt.ylabel(\"UMAP 2\")" @@ -480,7 +498,13 @@ } ], "source": [ - "plt.scatter(latent_vecs_transformed[:, 0], latent_vecs_transformed[:, 1], c=snr, cmap=\"viridis\", s=10)\n", + "plt.scatter(\n", + " latent_vecs_transformed[:, 0],\n", + " latent_vecs_transformed[:, 1],\n", + " c=snr,\n", + " cmap=\"viridis\",\n", + " s=10,\n", + ")\n", "plt.colorbar(label=\"SNR\")\n", "plt.xlabel(\"UMAP 1\")\n", "plt.ylabel(\"UMAP 2\")"