-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
41 changed files
with
769,607 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import argparse | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import torch | ||
import torchvision.transforms as transforms | ||
import json | ||
import mrcfile | ||
from itertools import islice | ||
|
||
from cryo_sbi.inference.models import build_models | ||
from cryo_sbi.inference import priors | ||
from cryo_sbi.inference.priors import get_image_priors, PriorLoader | ||
from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator | ||
|
||
def main(): | ||
print("hey!") | ||
device = "cuda" # Device for computations | ||
save_figures = False | ||
|
||
#snr_values = [0.001, 0.01, 0.1, 1.0] | ||
snr_values = ["uniform"] | ||
#num_batch_values = [1, 10, 100] | ||
num_batch_values = [100] | ||
for i in range(len(snr_values)): | ||
for j in range(len(num_batch_values)): | ||
snr_val = snr_values[i] | ||
|
||
# Load and modify image config file | ||
image_config_file = "image_params_cryoer_cryobife_high_snr.json" | ||
image_config = json.load(open(image_config_file)) | ||
|
||
if snr_val == "uniform": | ||
image_config["SNR"] = [0.001, 1.0] | ||
else: | ||
image_config["SNR"] = [snr_val, snr_val] | ||
snr_check = image_config["SNR"] | ||
print(f"SNR range: {snr_check}") | ||
|
||
# Load Neural Posterior Estimator | ||
train_config = json.load(open("Lars_hsp90/resnet18_encoder.json")) | ||
estimator = build_models.build_npe_flow_model(train_config) | ||
estimator.load_state_dict(torch.load("Lars_hsp90/hsp90_posterior.estimator")) | ||
estimator.cuda() | ||
estimator.eval(); | ||
|
||
num_batches = num_batch_values[j] | ||
simulation_batch_size = 1024 | ||
print(f"num batches: {num_batches}") | ||
|
||
log_posterior_mat, full_indices = batch_simulate_log_posterior(image_config, estimator, num_batches, device="cuda", simulation_batch_size=simulation_batch_size) | ||
|
||
if snr_val == "uniform": | ||
fname = f"posterior_matrices/posterior_sampling_snr_uniform_num_batches_idx_{j}" | ||
else: | ||
fname = f"posterior_matrices/posterior_sampling_snr_idx_{i}_num_batches_idx_{j}" | ||
#np.savez(f"{fname}.npz", log_posterior_mat=log_posterior_mat, indices=full_indices) | ||
#np.savetxt(f"{fname}.txt", log_posterior_mat) | ||
plt.hist(full_indices, bins=20, density=True); | ||
plt.xlabel(r"$\theta$"); | ||
plt.savefig("posterior_matrices/index_distribution.png", dpi=300) | ||
|
||
def batch_simulate_log_posterior( | ||
image_config, | ||
estimator, | ||
num_batches: int, | ||
n_workers: int = 1, | ||
device: str = "cpu", | ||
simulation_batch_size: int = 1024 | ||
) -> None: | ||
""" | ||
Args: | ||
image_config (str): path to image config file | ||
n_workers (int, optional): number of workers. Defaults to 1. | ||
device (str, optional): training device. Defaults to "cpu". | ||
saving_frequency (int, optional): frequency of saving model. Defaults to 20. | ||
whiten_filter (Union[None, str], optional): path to whiten filter. Defaults to None. | ||
Raises: | ||
Warning: No model state dict specified! --model_state_dict is empty | ||
Returns: | ||
Posterior matrix for images and models | ||
""" | ||
|
||
if image_config["MODEL_FILE"].endswith("npy"): | ||
|
||
# hsp90 gets special treatment | ||
if "hsp90" in image_config["MODEL_FILE"]: | ||
models = ( | ||
torch.from_numpy( | ||
np.load(image_config["MODEL_FILE"])[:, 0, :, :], | ||
).to(device).to(torch.float32) | ||
) | ||
else: | ||
models = ( | ||
torch.from_numpy( | ||
np.load(image_config["MODEL_FILE"]), | ||
) | ||
.to(device) | ||
.to(torch.float32) | ||
) | ||
else: | ||
models = torch.load( | ||
image_config["MODEL_FILE"], | ||
dtype=torch.float32, | ||
device=device | ||
) | ||
|
||
# Load Prior and relevant parameter values | ||
image_prior = get_image_priors(len(models) - 1, image_config, device="cpu") | ||
prior_loader = PriorLoader(image_prior, | ||
batch_size=simulation_batch_size, | ||
num_workers=n_workers) | ||
num_pixels = torch.tensor(image_config["N_PIXELS"], | ||
dtype=torch.float32, | ||
device=device) | ||
pixel_size = torch.tensor(image_config["PIXEL_SIZE"], | ||
dtype=torch.float32, | ||
device=device) | ||
|
||
# Initilize posterior matrix | ||
num_sim = int(num_batches*simulation_batch_size) | ||
sampling_indices = torch.arange(0, models.shape[0] + 1, 1, dtype=torch.float32).reshape(-1, 1).to(device) | ||
norm_indices = estimator.standardize(sampling_indices.to(device)) | ||
posterior_mat = np.zeros((num_sim, sampling_indices.shape[0])) | ||
full_indices = np.zeros(num_sim) | ||
|
||
j = 0 | ||
for parameters in islice(prior_loader, num_batches): | ||
|
||
# Sample a batch of parameters, simulate a batch of images | ||
indices, quaternions, sigma, shift, defocus, b_factor, amp, snr = parameters | ||
images = cryo_em_simulator( | ||
models, | ||
indices.to(device, non_blocking=True), | ||
quaternions.to(device, non_blocking=True), | ||
sigma.to(device, non_blocking=True), | ||
shift.to(device, non_blocking=True), | ||
defocus.to(device, non_blocking=True), | ||
b_factor.to(device, non_blocking=True), | ||
amp.to(device, non_blocking=True), | ||
snr.to(device, non_blocking=True), | ||
num_pixels, | ||
pixel_size, | ||
) | ||
batch_size = images.shape[0] | ||
|
||
# Evaluate the posterior of the flow conditioned on the image batch | ||
with torch.no_grad(): | ||
|
||
# Get posterior function for each image in batch | ||
flow_at_images = estimator.flow(images.cuda(non_blocking=True)) | ||
|
||
# Evaluate posterior function for each image in batch, at each model index | ||
indices_for_flow = norm_indices.repeat(1, images.shape[0]).unsqueeze(dim=2) # each flow needs its own set of parameters | ||
logprobs = flow_at_images.log_prob(indices_for_flow).to(device) | ||
probs = torch.exp(logprobs) / 10 # normalizing so that probs add to 1 (from change of variables to standardizing to [01, 1]) | ||
|
||
# Store posterior values, store model indices sampled | ||
posterior_mat[j:j + batch_size, :] = probs.T.detach().cpu().numpy() | ||
full_indices[j:j + batch_size] = indices.detach().cpu().numpy().flatten() | ||
j += batch_size | ||
log_posterior_mat = np.log(posterior_mat) | ||
return log_posterior_mat, full_indices | ||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.