Skip to content

Commit

Permalink
added code to have ultiple structures per cv
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 28, 2024
1 parent a60c974 commit c6cbf99
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 15 deletions.
46 changes: 42 additions & 4 deletions src/cryo_sbi/inference/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,17 @@ def get_image_priors(
ndims=1,
)

index_prior = zuko.distributions.BoxUniform(
lower=torch.tensor([0], dtype=torch.float32, device=device),
upper=torch.tensor([max_index], dtype=torch.float32, device=device),
)
# preload models to layout
models = torch.load(image_config["MODEL_FILE"])
if isinstance(models, list):
print("using multiple models per bin")
index_prior = MultiBinIndexPrior(*get_bin_layout(models), device)
else:
index_prior = zuko.distributions.BoxUniform(
lower=torch.tensor([0], dtype=torch.float32, device=device),
upper=torch.tensor([max_index], dtype=torch.float32, device=device),
)

quaternion_prior = QuaternionPrior(device)
if (
image_config.get("ROTATIONS")
Expand All @@ -136,6 +143,37 @@ def get_image_priors(
)


def get_bin_layout(models):
models_per_bin = []
layout = []
index_to_cv = []
start = 0
for i, m in enumerate(models):
models_per_bin.append(m.shape[0])
layout.append([j for j in range(start, start+m.shape[0])])
index_to_cv += [i for _ in range(m.shape[0])]
start = start + m.shape[0]
return torch.tensor(models_per_bin), layout, torch.tensor(index_to_cv, dtype=torch.float32)


class MultiBinIndexPrior:
def __init__(self, models_per_bin, layout, index_to_cv, device) -> None:
self.num_bins = len(models_per_bin)
self.models_per_bin = models_per_bin
self.index_to_cv = index_to_cv
self.layout = layout
self.device = device

def sample(self, shape) -> torch.Tensor:
samples = []
idx_cv = torch.randint(0, self.num_bins, shape, device="cpu").flatten()
for i in range(len(idx_cv)):
idx = torch.randint(0, self.models_per_bin[idx_cv[i]], (1,), device="cpu")
samples.append(self.layout[idx_cv[i]][idx])

return torch.tensor(samples).reshape(-1, 1).to(torch.long).to(self.device)


class QuaternionPrior:
def __init__(self, device) -> None:
self.device = device
Expand Down
8 changes: 6 additions & 2 deletions src/cryo_sbi/inference/train_npe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,13 @@ def npe_train_no_saving(
.to(torch.float32)
)
else:
models = torch.load(image_config["MODEL_FILE"]).to(device).to(torch.float32)
models = torch.load(image_config["MODEL_FILE"])
if isinstance(models, list):
models = torch.cat(models, dim=0).to(device).to(torch.float32)
print("model shape", models.shape)

image_prior = get_image_priors(len(models) - 1, image_config, device="cpu")
index_to_cv = image_prior.priors[0].index_to_cv.to(device)
prior_loader = PriorLoader(
image_prior, batch_size=simulation_batch_size, num_workers=n_workers
)
Expand Down Expand Up @@ -158,7 +162,7 @@ def npe_train_no_saving(
losses.append(
step(
loss(
_indices.to(device, non_blocking=True),
index_to_cv[_indices].to(device, non_blocking=True),
_images.to(device, non_blocking=True),
)
)
Expand Down
18 changes: 9 additions & 9 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cryo_sbi.wpa_simulator.image_generation import project_density
from cryo_sbi.wpa_simulator.noise import add_noise
from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image
from cryo_sbi.inference.priors import get_image_priors
from cryo_sbi.inference.priors import get_image_priors, get_bin_layout
from cryo_sbi.wpa_simulator.validate_image_config import check_image_params


Expand Down Expand Up @@ -103,11 +103,11 @@ def _load_models(self) -> None:
.to(torch.float32)
)
elif self._config["MODEL_FILE"].endswith("pt"):
models = (
torch.load(self._config["MODEL_FILE"])
.to(self._device)
.to(torch.float32)
)
models = torch.load(self._config["MODEL_FILE"])
if isinstance(models, list):
self.num_models_per_bin, self.layout, self.index_to_cv = get_bin_layout(models)
models = torch.cat(models, dim=0)
models = models.to(self._device).to(torch.float32)

else:
raise NotImplementedError(
Expand Down Expand Up @@ -150,9 +150,9 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No
assert isinstance(
indices, torch.Tensor
), "Indices are not a torch.tensor, converting to torch.tensor."
assert (
indices.dtype == torch.float32
), "Indices are not a torch.float32, converting to torch.float32."
#assert (
# indices.dtype == torch.float32
#), "Indices are not a torch.float32, converting to torch.float32."
assert (
indices.ndim == 2
), "Indices are not a 2D tensor, converting to 2D tensor. With shape (batch_size, 1)."
Expand Down

0 comments on commit c6cbf99

Please sign in to comment.