From 9a697d8013653b8a3c716269c5f0b8f755fd3257 Mon Sep 17 00:00:00 2001
From: Dingel321
Date: Fri, 14 Jun 2024 15:16:57 -0400
Subject: [PATCH] added random batching for experimental data

---
 src/cryo_sbi/inference/losses.py          | 1 -
 src/cryo_sbi/inference/train_npe_model.py | 5 +++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/cryo_sbi/inference/losses.py b/src/cryo_sbi/inference/losses.py
index d65e8ff..3d13751 100644
--- a/src/cryo_sbi/inference/losses.py
+++ b/src/cryo_sbi/inference/losses.py
@@ -55,5 +55,4 @@ def forward(self, theta: torch.Tensor, x: torch.Tensor, x_obs: torch.Tensor) ->
         )
 
         log_p = self.estimator(theta, x)
-        print(log_p.mean(), summary_stats_regularization)
         return -log_p.mean() + summary_stats_regularization
\ No newline at end of file
diff --git a/src/cryo_sbi/inference/train_npe_model.py b/src/cryo_sbi/inference/train_npe_model.py
index 5666af9..052ef71 100644
--- a/src/cryo_sbi/inference/train_npe_model.py
+++ b/src/cryo_sbi/inference/train_npe_model.py
@@ -119,7 +119,7 @@ def npe_train_no_saving(
     )
     loss = NPERobustStatsLoss(estimator, gamma)
 
-    experimental_particles = torch.load(experimental_particles, map_location=device)
+    experimental_particles = torch.load(experimental_particles)
 
     optimizer = optim.AdamW(
         estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
@@ -160,12 +160,13 @@ def npe_train_no_saving(
                indices.split(train_config["BATCH_SIZE"]),
                images.split(train_config["BATCH_SIZE"]),
            ):
+               random_indices = torch.randperm(experimental_particles.size(0))[:train_config["BATCH_SIZE"]]
                losses.append(
                    step(
                        loss(
                            _indices.to(device, non_blocking=True),
                            _images.to(device, non_blocking=True),
-                           experimental_particles.to(device, non_blocking=True),
+                           experimental_particles[random_indices].to(device, non_blocking=True),
                        )
                    )
                )
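
The core change is that each optimization step now pairs the simulated batch with a freshly drawn random subset of the experimental particles rather than the full stack. Below is a minimal, self-contained sketch of that sampling pattern; the tensor shape, batch_size value, and loop bounds are illustrative stand-ins, not values taken from the patch or the cryoSBI configuration.

    import torch

    # Illustrative stand-ins: a fake stack of 1000 "experimental" images of size
    # 64x64 and an arbitrary batch size (the real code reads BATCH_SIZE from
    # train_config).
    experimental_particles = torch.randn(1000, 64, 64)
    batch_size = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for _ in range(5):  # stand-in for the loop over simulated training batches
        # Draw a fresh random subset of experimental particles for this batch,
        # so the regularization term sees different observed images each step.
        random_indices = torch.randperm(experimental_particles.size(0))[:batch_size]
        x_obs = experimental_particles[random_indices].to(device, non_blocking=True)
        # x_obs is then passed to the loss alongside the simulated (theta, x) batch.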