From 84af9adb0e0bbff4cf06abd32cc8ff583e581b6c Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Thu, 13 Jun 2024 16:45:44 -0400 Subject: [PATCH] implemented first version of robust loss --- src/cryo_sbi/inference/losses.py | 59 +++++++++++++++++++++++ src/cryo_sbi/inference/train_npe_model.py | 10 +++- 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 src/cryo_sbi/inference/losses.py diff --git a/src/cryo_sbi/inference/losses.py b/src/cryo_sbi/inference/losses.py new file mode 100644 index 0000000..d65e8ff --- /dev/null +++ b/src/cryo_sbi/inference/losses.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn + + +def kernel_matrix(x, y, l): + d = torch.cdist(x, y)**2 + + kernel = torch.exp(-(1 / (2 * l ** 2)) * d) + + return kernel + + +def mmd_unweighted(x, y, lengthscale): + """ Approximates the squared MMD between samples x_i ~ P and y_i ~ Q + """ + + m = x.shape[0] + n = y.shape[0] + + z = torch.cat((x, y), dim=0) + + K = kernel_matrix(z, z, lengthscale) + + kxx = K[0:m, 0:m] + kyy = K[m:(m + n), m:(m + n)] + kxy = K[0:m, m:(m + n)] + + return (1 / m ** 2) * torch.sum(kxx) - (2 / (m * n)) * torch.sum(kxy) + (1 / n ** 2) * torch.sum(kyy) + + +def median_heuristic(y): + a = torch.cdist(y, y)**2 + return torch.sqrt(torch.median(a / 2)) + + +class NPERobustStatsLoss(nn.Module): + + def __init__(self, estimator: nn.Module, gamma: float): + super().__init__() + + self.estimator = estimator + self.gamma = gamma + + def forward(self, theta: torch.Tensor, x: torch.Tensor, x_obs: torch.Tensor) -> torch.Tensor: + + self.estimator.embedding.eval() + latent_vecs_x = self.estimator.embedding(x) + latent_vecs_x_obs = self.estimator.embedding(x_obs) + self.estimator.embedding.train() + + summary_stats_regularization = self.gamma * mmd_unweighted( + latent_vecs_x, + latent_vecs_x_obs, + median_heuristic(x) + ) + log_p = self.estimator(theta, x) + + print(log_p.mean(), summary_stats_regularization) + return -log_p.mean() + summary_stats_regularization \ No newline at end of file diff --git a/src/cryo_sbi/inference/train_npe_model.py b/src/cryo_sbi/inference/train_npe_model.py index 325aa26..5666af9 100644 --- a/src/cryo_sbi/inference/train_npe_model.py +++ b/src/cryo_sbi/inference/train_npe_model.py @@ -18,6 +18,7 @@ from cryo_sbi.wpa_simulator.validate_image_config import check_image_params from cryo_sbi.inference.validate_train_config import check_train_params import cryo_sbi.utils.image_utils as img_utils +from cryo_sbi.inference.losses import NPERobustStatsLoss def load_model( @@ -51,11 +52,13 @@ def npe_train_no_saving( estimator_file: str, loss_file: str, train_from_checkpoint: bool = False, - model_state_dict: Union[str, None] = None, + model_state_dict: Union[str, None] = False, n_workers: int = 1, device: str = "cpu", saving_frequency: int = 20, simulation_batch_size: int = 1024, + gamma: float = 1.0, + experimental_particles: Union[str, None] = None, ) -> None: """ Train NPE model by simulating training data on the fly. @@ -115,7 +118,9 @@ def npe_train_no_saving( train_config, model_state_dict, device, train_from_checkpoint ) - loss = NPELoss(estimator) + loss = NPERobustStatsLoss(estimator, gamma) + experimental_particles = torch.load(experimental_particles, map_location=device) + optimizer = optim.AdamW( estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001 ) @@ -160,6 +165,7 @@ def npe_train_no_saving( loss( _indices.to(device, non_blocking=True), _images.to(device, non_blocking=True), + experimental_particles.to(device, non_blocking=True), ) ) )