diff --git a/src/cryo_sbi/inference/priors.py b/src/cryo_sbi/inference/priors.py index f8fbb29..d45cd7f 100644 --- a/src/cryo_sbi/inference/priors.py +++ b/src/cryo_sbi/inference/priors.py @@ -110,10 +110,17 @@ def get_image_priors( ndims=1, ) - index_prior = zuko.distributions.BoxUniform( - lower=torch.tensor([0], dtype=torch.float32, device=device), - upper=torch.tensor([max_index], dtype=torch.float32, device=device), - ) + # preload models to layout + models = torch.load(image_config["MODEL_FILE"]) + if isinstance(models, list): + print("using multiple models per bin") + index_prior = MultiBinIndexPrior(*get_bin_layout(models), device) + else: + index_prior = zuko.distributions.BoxUniform( + lower=torch.tensor([0], dtype=torch.float32, device=device), + upper=torch.tensor([max_index], dtype=torch.float32, device=device), + ) + quaternion_prior = QuaternionPrior(device) if ( image_config.get("ROTATIONS") @@ -136,6 +143,37 @@ def get_image_priors( ) +def get_bin_layout(models): + models_per_bin = [] + layout = [] + index_to_cv = [] + start = 0 + for i, m in enumerate(models): + models_per_bin.append(m.shape[0]) + layout.append([j for j in range(start, start+m.shape[0])]) + index_to_cv += [i for _ in range(m.shape[0])] + start = start + m.shape[0] + return torch.tensor(models_per_bin), layout, torch.tensor(index_to_cv, dtype=torch.float32) + + +class MultiBinIndexPrior: + def __init__(self, models_per_bin, layout, index_to_cv, device) -> None: + self.num_bins = len(models_per_bin) + self.models_per_bin = models_per_bin + self.index_to_cv = index_to_cv + self.layout = layout + self.device = device + + def sample(self, shape) -> torch.Tensor: + samples = [] + idx_cv = torch.randint(0, self.num_bins, shape, device="cpu").flatten() + for i in range(len(idx_cv)): + idx = torch.randint(0, self.models_per_bin[idx_cv[i]], (1,), device="cpu") + samples.append(self.layout[idx_cv[i]][idx]) + + return torch.tensor(samples).reshape(-1, 1).to(torch.long).to(self.device) + + class QuaternionPrior: def __init__(self, device) -> None: self.device = device diff --git a/src/cryo_sbi/inference/train_npe_model.py b/src/cryo_sbi/inference/train_npe_model.py index 325aa26..a699ca5 100644 --- a/src/cryo_sbi/inference/train_npe_model.py +++ b/src/cryo_sbi/inference/train_npe_model.py @@ -97,9 +97,13 @@ def npe_train_no_saving( .to(torch.float32) ) else: - models = torch.load(image_config["MODEL_FILE"]).to(device).to(torch.float32) + models = torch.load(image_config["MODEL_FILE"]) + if isinstance(models, list): + models = torch.cat(models, dim=0).to(device).to(torch.float32) + print("model shape", models.shape) image_prior = get_image_priors(len(models) - 1, image_config, device="cpu") + index_to_cv = image_prior.priors[0].index_to_cv.to(device) prior_loader = PriorLoader( image_prior, batch_size=simulation_batch_size, num_workers=n_workers ) @@ -158,7 +162,7 @@ def npe_train_no_saving( losses.append( step( loss( - _indices.to(device, non_blocking=True), + index_to_cv[_indices].to(device, non_blocking=True), _images.to(device, non_blocking=True), ) ) diff --git a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py index efcc7c9..dbc5503 100644 --- a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py +++ b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py @@ -7,7 +7,7 @@ from cryo_sbi.wpa_simulator.image_generation import project_density from cryo_sbi.wpa_simulator.noise import add_noise from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image -from cryo_sbi.inference.priors import get_image_priors +from cryo_sbi.inference.priors import get_image_priors, get_bin_layout from cryo_sbi.wpa_simulator.validate_image_config import check_image_params @@ -103,11 +103,11 @@ def _load_models(self) -> None: .to(torch.float32) ) elif self._config["MODEL_FILE"].endswith("pt"): - models = ( - torch.load(self._config["MODEL_FILE"]) - .to(self._device) - .to(torch.float32) - ) + models = torch.load(self._config["MODEL_FILE"]) + if isinstance(models, list): + self.num_models_per_bin, self.layout, self.index_to_cv = get_bin_layout(models) + models = torch.cat(models, dim=0) + models = models.to(self._device).to(torch.float32) else: raise NotImplementedError( @@ -150,9 +150,9 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No assert isinstance( indices, torch.Tensor ), "Indices are not a torch.tensor, converting to torch.tensor." - assert ( - indices.dtype == torch.float32 - ), "Indices are not a torch.float32, converting to torch.float32." + #assert ( + # indices.dtype == torch.float32 + #), "Indices are not a torch.float32, converting to torch.float32." assert ( indices.ndim == 2 ), "Indices are not a 2D tensor, converting to 2D tensor. With shape (batch_size, 1)."