Skip to content

Commit

Permalink
formatted code
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Mar 7, 2024
1 parent 0f8be74 commit e156824
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/cryo_sbi/inference/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def gen_quat() -> torch.Tensor:
count = 0
while count < 1:
quat = 2 * torch.rand(size=(4,)) - 1
norm = torch.sqrt(torch.sum(quat ** 2))
norm = torch.sqrt(torch.sum(quat**2))
if 0.2 <= norm <= 1.0:
quat /= norm
count += 1
Expand Down Expand Up @@ -207,7 +207,7 @@ class PriorLoader(DataLoader):
def __init__(
self,
prior: Distribution,
batch_size: int = 2 ** 8, # 256
batch_size: int = 2**8, # 256
**kwargs,
):
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/utils/estimator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def evaluate_log_prob(
Returns:
torch.Tensor: The log probabilities of the images under the estimator.
"""

# batching images if necessary
if images.shape[0] > batch_size and batch_size > 0:
images = torch.split(images, split_size_or_sections=batch_size, dim=0)
Expand Down
6 changes: 3 additions & 3 deletions src/cryo_sbi/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def circular_mask(n_pixels: int, radius: int, inside: bool = True) -> torch.Tens
r_2d = grid[None, :] ** 2 + grid[:, None] ** 2

if inside is True:
mask = r_2d < radius ** 2
mask = r_2d < radius**2
else:
mask = r_2d > radius ** 2
mask = r_2d > radius**2

return mask

Expand Down Expand Up @@ -183,7 +183,7 @@ def __init__(self, image_size: int, sigma: int):
-0.5 * (image_size - 1), 0.5 * (image_size - 1), image_size
)
self._r_2d = self._grid[None, :] ** 2 + self._grid[:, None] ** 2
self._mask = torch.exp(-self._r_2d / (2 * sigma ** 2))
self._mask = torch.exp(-self._r_2d / (2 * sigma**2))

def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Expand Down
28 changes: 17 additions & 11 deletions src/cryo_sbi/utils/visualize_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,28 @@
import torch


def _scatter_plot_models(model: torch.Tensor, view_angles : tuple = (30, 45), **plot_kwargs: dict) -> None:
def _scatter_plot_models(
model: torch.Tensor, view_angles: tuple = (30, 45), **plot_kwargs: dict
) -> None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax = fig.add_subplot(111, projection="3d")
ax.view_init(*view_angles)

ax.scatter(*model, **plot_kwargs)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")


def _sphere_plot_models(model: torch.Tensor, radius: float = 4, view_angles : tuple = (30, 45), **plot_kwargs: dict,) -> None:
def _sphere_plot_models(
model: torch.Tensor,
radius: float = 4,
view_angles: tuple = (30, 45),
**plot_kwargs: dict,
) -> None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax = fig.add_subplot(111, projection="3d")
ax.view_init(30, 45)

spheres = []
Expand All @@ -35,9 +42,9 @@ def _sphere_plot_models(model: torch.Tensor, radius: float = 4, view_angles : tu

ax.plot_surface(x, y, z, **plot_kwargs)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")


def plot_model(model: torch.Tensor, method: str = "scatter", **kwargs) -> None:
Expand Down Expand Up @@ -70,4 +77,3 @@ def plot_model(model: torch.Tensor, method: str = "scatter", **kwargs) -> None:

else:
raise ValueError(f"Unknown method {method}. Use 'scatter' or 'sphere'.")

4 changes: 2 additions & 2 deletions src/cryo_sbi/wpa_simulator/ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch.
freq_pix_1d = torch.fft.fftfreq(num_pixels, d=pixel_size, device=image.device)
x, y = torch.meshgrid(freq_pix_1d, freq_pix_1d, indexing="ij")

freq2_2d = x ** 2 + y ** 2
freq2_2d = x**2 + y**2
freq2_2d = freq2_2d.expand(num_batch, -1, -1)
imag = torch.zeros_like(freq2_2d, device=image.device) * 1j

Expand All @@ -30,7 +30,7 @@ def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch.

ctf = (
-amp * torch.cos(phase * freq2_2d * 0.5)
- torch.sqrt(1 - amp ** 2) * torch.sin(phase * freq2_2d * 0.5)
- torch.sqrt(1 - amp**2) * torch.sin(phase * freq2_2d * 0.5)
+ imag
)
ctf = ctf * env / amp
Expand Down
4 changes: 2 additions & 2 deletions src/cryo_sbi/wpa_simulator/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def gen_quat() -> torch.Tensor:
count = 0
while count < 1:
quat = 2 * torch.rand(size=(4,)) - 1
norm = torch.sqrt(torch.sum(quat ** 2))
norm = torch.sqrt(torch.sum(quat**2))
if 0.2 <= norm <= 1.0:
quat /= norm
count += 1
Expand Down Expand Up @@ -72,7 +72,7 @@ def project_density(
"""

num_batch, _, num_atoms = coords.shape
norm = 1 / (2 * torch.pi * sigma ** 2 * num_atoms)
norm = 1 / (2 * torch.pi * sigma**2 * num_atoms)

grid_min = -pixel_size * num_pixels * 0.5
grid_max = pixel_size * num_pixels * 0.5
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/wpa_simulator/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def circular_mask(n_pixels: int, radius: int, device: str = "cpu") -> torch.Tens
-0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels, device=device
)
r_2d = grid[None, :] ** 2 + grid[:, None] ** 2
mask = r_2d < radius ** 2
mask = r_2d < radius**2

return mask

Expand Down
6 changes: 3 additions & 3 deletions tests/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def test_circular_mask():

assert inside_mask.shape == (n_pixels, n_pixels)
assert outside_mask.shape == (n_pixels, n_pixels)
assert inside_mask.sum().item() == pytest.approx(radius ** 2 * 3.14159, abs=10)
assert inside_mask.sum().item() == pytest.approx(radius**2 * 3.14159, abs=10)
assert outside_mask.sum().item() == pytest.approx(
n_pixels ** 2 - radius ** 2 * 3.14159, abs=10
n_pixels**2 - radius**2 * 3.14159, abs=10
)


Expand All @@ -27,7 +27,7 @@ def test_mask_class():
masked_image = mask(image)
assert masked_image.shape == (image_size, image_size)
assert masked_image[inside].sum().item() == pytest.approx(
image_size ** 2 - radius ** 2 * 3.14159, abs=10
image_size**2 - radius**2 * 3.14159, abs=10
)


Expand Down
14 changes: 10 additions & 4 deletions tests/test_visualize_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,27 @@

def test_plot_model_scatter():
model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
plot_model(model, method="scatter") # No assertion, just checking if it runs without errors
plot_model(
model, method="scatter"
) # No assertion, just checking if it runs without errors


def test_plot_model_sphere():
model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
plot_model(model, method="sphere") # No assertion, just checking if it runs without errors
plot_model(
model, method="sphere"
) # No assertion, just checking if it runs without errors


def test_plot_model_invalid_model():
model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # Invalid shape, should have 3 rows
model = torch.tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
) # Invalid shape, should have 3 rows
with pytest.raises(AssertionError):
plot_model(model, method="scatter")


def test_plot_model_invalid_method():
model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
with pytest.raises(ValueError):
plot_model(model, method="invalid_method")
plot_model(model, method="invalid_method")
66 changes: 45 additions & 21 deletions tutorials/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@
"models = []\n",
"for side_length in side_lengths:\n",
" model = [\n",
" [side_length, -side_length, side_length, -side_length],\n",
" [side_length, side_length, -side_length, -side_length], \n",
" [0.0, 0.0, 0.0, 0.0]\n",
" [side_length, -side_length, side_length, -side_length],\n",
" [side_length, side_length, -side_length, -side_length],\n",
" [0.0, 0.0, 0.0, 0.0],\n",
" ]\n",
" models.append(model)\n",
"models = torch.tensor(models)"
Expand Down Expand Up @@ -91,14 +91,20 @@
"source": [
"fig = plt.figure()\n",
"ax = fig.add_subplot()\n",
"for i, c in zip([0, 25, 50, 75, 99], ['red', 'orange', 'green', 'blue', 'purple']):\n",
" ax.scatter(models[i, 0, 0], models[i, 1, 0], s=60, color=c, label=f\"Model with angle : {i:.2f}\")\n",
"for i, c in zip([0, 25, 50, 75, 99], [\"red\", \"orange\", \"green\", \"blue\", \"purple\"]):\n",
" ax.scatter(\n",
" models[i, 0, 0],\n",
" models[i, 1, 0],\n",
" s=60,\n",
" color=c,\n",
" label=f\"Model with angle : {i:.2f}\",\n",
" )\n",
" ax.scatter(models[i, 0, 1], models[i, 1, 1], s=60, color=c)\n",
" ax.scatter(models[i, 0, 2], models[i, 1, 2], s=60, color=c)\n",
" ax.scatter(models[i, 0, 3], models[i, 1, 3], s=60, color=c)\n",
"\n",
"ax.set_xlabel('X')\n",
"ax.set_ylabel('Y')\n",
"ax.set_xlabel(\"X\")\n",
"ax.set_ylabel(\"Y\")\n",
"plt.legend()"
]
},
Expand All @@ -112,7 +118,7 @@
"for i in range(100):\n",
" models[i] = models[i] - models[i].mean(dim=1, keepdim=True)\n",
"\n",
"torch.save(models, 'models.pt')"
"torch.save(models, \"models.pt\")"
]
},
{
Expand Down Expand Up @@ -155,7 +161,9 @@
"metadata": {},
"outputs": [],
"source": [
"simulator = CryoEmSimulator(\"simulation_parameters.json\") # creating simulator with simulation parameters"
"simulator = CryoEmSimulator(\n",
" \"simulation_parameters.json\"\n",
") # creating simulator with simulation parameters"
]
},
{
Expand All @@ -164,7 +172,9 @@
"metadata": {},
"outputs": [],
"source": [
"images, parameters = simulator.simulate(num_sim=5000, return_parameters=True) # simulating images and save parameters"
"images, parameters = simulator.simulate(\n",
" num_sim=5000, return_parameters=True\n",
") # simulating images and save parameters"
]
},
{
Expand All @@ -173,8 +183,8 @@
"metadata": {},
"outputs": [],
"source": [
"dist = parameters[0] # extracting distance from parameters\n",
"snr = parameters[-1] # extracting snr from parameters"
"dist = parameters[0] # extracting distance from parameters\n",
"snr = parameters[-1] # extracting snr from parameters"
]
},
{
Expand Down Expand Up @@ -204,7 +214,9 @@
"fig, axes = plt.subplots(4, 4, figsize=(6, 6))\n",
"for idx, ax in enumerate(axes.flatten()):\n",
" ax.imshow(images[idx], vmin=-3, vmax=3, cmap=\"gray\")\n",
" ax.set_title(f\"Side: {side_lengths[dist[idx].round().long()].item():.2f}\", fontsize=10)\n",
" ax.set_title(\n",
" f\"Side: {side_lengths[dist[idx].round().long()].item():.2f}\", fontsize=10\n",
" )\n",
" ax.axis(\"off\")"
]
},
Expand Down Expand Up @@ -270,12 +282,12 @@
" \"simulation_parameters.json\",\n",
" \"training_parameters.json\",\n",
" 150,\n",
" \"tutorial_estimator.pt\", # name of the estimator file\n",
" \"tutorial.loss\", # name of the loss file\n",
" n_workers=4, # number of workers for data loading\n",
" device=\"cuda\", # device to use for training and simulation\n",
" saving_frequency=100, # frequency of saving the model\n",
" simulation_batch_size=160, # batch size for simulation\n",
" \"tutorial_estimator.pt\", # name of the estimator file\n",
" \"tutorial.loss\", # name of the loss file\n",
" n_workers=4, # number of workers for data loading\n",
" device=\"cuda\", # device to use for training and simulation\n",
" saving_frequency=100, # frequency of saving the model\n",
" simulation_batch_size=160, # batch size for simulation\n",
")"
]
},
Expand Down Expand Up @@ -447,7 +459,13 @@
}
],
"source": [
"plt.scatter(latent_vecs_transformed[:, 0], latent_vecs_transformed[:, 1], c=dist, cmap=\"viridis\", s=10)\n",
"plt.scatter(\n",
" latent_vecs_transformed[:, 0],\n",
" latent_vecs_transformed[:, 1],\n",
" c=dist,\n",
" cmap=\"viridis\",\n",
" s=10,\n",
")\n",
"plt.colorbar(label=\"Side length\")\n",
"plt.xlabel(\"UMAP 1\")\n",
"plt.ylabel(\"UMAP 2\")"
Expand Down Expand Up @@ -480,7 +498,13 @@
}
],
"source": [
"plt.scatter(latent_vecs_transformed[:, 0], latent_vecs_transformed[:, 1], c=snr, cmap=\"viridis\", s=10)\n",
"plt.scatter(\n",
" latent_vecs_transformed[:, 0],\n",
" latent_vecs_transformed[:, 1],\n",
" c=snr,\n",
" cmap=\"viridis\",\n",
" s=10,\n",
")\n",
"plt.colorbar(label=\"SNR\")\n",
"plt.xlabel(\"UMAP 1\")\n",
"plt.ylabel(\"UMAP 2\")"
Expand Down

0 comments on commit e156824

Please sign in to comment.