diff --git a/src/cryo_sbi/inference/models/estimator_models.py b/src/cryo_sbi/inference/models/estimator_models.py index a1ff0fe..9c5a1cc 100644 --- a/src/cryo_sbi/inference/models/estimator_models.py +++ b/src/cryo_sbi/inference/models/estimator_models.py @@ -94,7 +94,7 @@ def __init__( super().__init__() self.npe = NPE( - 1, + 8, output_embedding_dim, transforms=num_transforms, build=flow, @@ -144,4 +144,4 @@ def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor: """ samples_standardized = self.flow(x).sample(shape) - return self.standardize.transform(samples_standardized) + return self.standardize.transform(samples_standardized) \ No newline at end of file diff --git a/src/cryo_sbi/inference/train_npe_model.py b/src/cryo_sbi/inference/train_npe_model.py index 325aa26..b24f148 100644 --- a/src/cryo_sbi/inference/train_npe_model.py +++ b/src/cryo_sbi/inference/train_npe_model.py @@ -122,7 +122,7 @@ def npe_train_no_saving( step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"]) mean_loss = [] - print("Training neural netowrk:") + print("Training neural network with all params:") estimator.train() with tqdm(range(epochs), unit="epoch") as tq: for epoch in tq: @@ -151,14 +151,24 @@ def npe_train_no_saving( num_pixels, pixel_size, ) - for _indices, _images in zip( - indices.split(train_config["BATCH_SIZE"]), - images.split(train_config["BATCH_SIZE"]), + flow_params = torch.cat( + ( + indices, + quaternions, + defocus.reshape(-1, 1), + b_factor.reshape(-1, 1), + snr.reshape(-1, 1), + ), + dim=1, + ) + for _flow_params, _images in zip( + flow_params.split(train_config["BATCH_SIZE"]), + images.split(train_config["BATCH_SIZE"]) ): losses.append( step( loss( - _indices.to(device, non_blocking=True), + _flow_params.to(device, non_blocking=True), _images.to(device, non_blocking=True), ) )