Skip to content

Commit

Permalink
added random batchinng for experimental data
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 14, 2024
1 parent 84af9ad commit 9a697d8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
1 change: 0 additions & 1 deletion src/cryo_sbi/inference/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,4 @@ def forward(self, theta: torch.Tensor, x: torch.Tensor, x_obs: torch.Tensor) ->
)
log_p = self.estimator(theta, x)

print(log_p.mean(), summary_stats_regularization)
return -log_p.mean() + summary_stats_regularization
5 changes: 3 additions & 2 deletions src/cryo_sbi/inference/train_npe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def npe_train_no_saving(
)

loss = NPERobustStatsLoss(estimator, gamma)
experimental_particles = torch.load(experimental_particles, map_location=device)
experimental_particles = torch.load(experimental_particles)

optimizer = optim.AdamW(
estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
Expand Down Expand Up @@ -160,12 +160,13 @@ def npe_train_no_saving(
indices.split(train_config["BATCH_SIZE"]),
images.split(train_config["BATCH_SIZE"]),
):
random_indices = torch.randperm(experimental_particles.size(0))[:train_config["BATCH_SIZE"]]
losses.append(
step(
loss(
_indices.to(device, non_blocking=True),
_images.to(device, non_blocking=True),
experimental_particles.to(device, non_blocking=True),
experimental_particles[random_indices].to(device, non_blocking=True),
)
)
)
Expand Down

0 comments on commit 9a697d8

Please sign in to comment.