Skip to content

Commit

Permalink
implemented first version of robust loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 13, 2024
1 parent 83d90b9 commit 84af9ad
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
59 changes: 59 additions & 0 deletions src/cryo_sbi/inference/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn


def kernel_matrix(x, y, l):
d = torch.cdist(x, y)**2

kernel = torch.exp(-(1 / (2 * l ** 2)) * d)

return kernel


def mmd_unweighted(x, y, lengthscale):
""" Approximates the squared MMD between samples x_i ~ P and y_i ~ Q
"""

m = x.shape[0]
n = y.shape[0]

z = torch.cat((x, y), dim=0)

K = kernel_matrix(z, z, lengthscale)

kxx = K[0:m, 0:m]
kyy = K[m:(m + n), m:(m + n)]
kxy = K[0:m, m:(m + n)]

return (1 / m ** 2) * torch.sum(kxx) - (2 / (m * n)) * torch.sum(kxy) + (1 / n ** 2) * torch.sum(kyy)


def median_heuristic(y):
a = torch.cdist(y, y)**2
return torch.sqrt(torch.median(a / 2))


class NPERobustStatsLoss(nn.Module):

def __init__(self, estimator: nn.Module, gamma: float):
super().__init__()

self.estimator = estimator
self.gamma = gamma

def forward(self, theta: torch.Tensor, x: torch.Tensor, x_obs: torch.Tensor) -> torch.Tensor:

self.estimator.embedding.eval()
latent_vecs_x = self.estimator.embedding(x)
latent_vecs_x_obs = self.estimator.embedding(x_obs)
self.estimator.embedding.train()

summary_stats_regularization = self.gamma * mmd_unweighted(
latent_vecs_x,
latent_vecs_x_obs,
median_heuristic(x)
)
log_p = self.estimator(theta, x)

print(log_p.mean(), summary_stats_regularization)
return -log_p.mean() + summary_stats_regularization
10 changes: 8 additions & 2 deletions src/cryo_sbi/inference/train_npe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
from cryo_sbi.inference.validate_train_config import check_train_params
import cryo_sbi.utils.image_utils as img_utils
from cryo_sbi.inference.losses import NPERobustStatsLoss


def load_model(
Expand Down Expand Up @@ -51,11 +52,13 @@ def npe_train_no_saving(
estimator_file: str,
loss_file: str,
train_from_checkpoint: bool = False,
model_state_dict: Union[str, None] = None,
model_state_dict: Union[str, None] = False,
n_workers: int = 1,
device: str = "cpu",
saving_frequency: int = 20,
simulation_batch_size: int = 1024,
gamma: float = 1.0,
experimental_particles: Union[str, None] = None,
) -> None:
"""
Train NPE model by simulating training data on the fly.
Expand Down Expand Up @@ -115,7 +118,9 @@ def npe_train_no_saving(
train_config, model_state_dict, device, train_from_checkpoint
)

loss = NPELoss(estimator)
loss = NPERobustStatsLoss(estimator, gamma)
experimental_particles = torch.load(experimental_particles, map_location=device)

optimizer = optim.AdamW(
estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
)
Expand Down Expand Up @@ -160,6 +165,7 @@ def npe_train_no_saving(
loss(
_indices.to(device, non_blocking=True),
_images.to(device, non_blocking=True),
experimental_particles.to(device, non_blocking=True),
)
)
)
Expand Down

0 comments on commit 84af9ad

Please sign in to comment.