Skip to content

Commit

Permalink
include all params in training
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 12, 2024
1 parent 83d90b9 commit 23f1ee4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/cryo_sbi/inference/models/estimator_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
super().__init__()

self.npe = NPE(
1,
8,
output_embedding_dim,
transforms=num_transforms,
build=flow,
Expand Down Expand Up @@ -144,4 +144,4 @@ def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor:
"""

samples_standardized = self.flow(x).sample(shape)
return self.standardize.transform(samples_standardized)
return self.standardize.transform(samples_standardized)
20 changes: 15 additions & 5 deletions src/cryo_sbi/inference/train_npe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def npe_train_no_saving(
step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"])
mean_loss = []

print("Training neural netowrk:")
print("Training neural netowrk with all params:")
estimator.train()
with tqdm(range(epochs), unit="epoch") as tq:
for epoch in tq:
Expand Down Expand Up @@ -151,14 +151,24 @@ def npe_train_no_saving(
num_pixels,
pixel_size,
)
for _indices, _images in zip(
indices.split(train_config["BATCH_SIZE"]),
images.split(train_config["BATCH_SIZE"]),
flow_params = torch.cat(
(
indices,
quaternions,
defocus.reshape(-1, 1),
b_factor.reshape(-1, 1),
snr.reshape(-1, 1),
),
dim=1,
)
for _flow_params, _images in zip(
flow_params.split(train_config["BATCH_SIZE"]),
images.split(train_config["BATCH_SIZE"])
):
losses.append(
step(
loss(
_indices.to(device, non_blocking=True),
_flow_params.to(device, non_blocking=True),
_images.to(device, non_blocking=True),
)
)
Expand Down

0 comments on commit 23f1ee4

Please sign in to comment.