more ensemble stuff
aevans1 committed Sep 25, 2023
1 parent 3a21bfe commit 7684606
Showing 41 changed files with 769,607 additions and 19 deletions.
7 changes: 6 additions & 1 deletion Lukes_folder/6wxb/6wxb_MMD_bandwidths.ipynb
@@ -3919,7 +3919,7 @@
   ],
   "metadata": {
    "kernelspec": {
-    "display_name": ".venv",
+    "display_name": "Python 3.9.15 64-bit",
     "language": "python",
     "name": "python3"
    },
@@ -3934,6 +3934,11 @@
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.9.15"
   },
+  "vscode": {
+   "interpreter": {
+    "hash": "7b7fbdd20bcc2083504065e64dd68e11295ac29c39a09e225403f090756a3e6a"
+   }
+  }
  },
  "nbformat": 4,
7 changes: 6 additions & 1 deletion Lukes_folder/6wxb/6wxb_MMD_index.ipynb
@@ -529,7 +529,7 @@
   ],
   "metadata": {
    "kernelspec": {
-    "display_name": ".venv",
+    "display_name": "Python 3.9.15 64-bit",
     "language": "python",
     "name": "python3"
    },
@@ -544,6 +544,11 @@
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.9.15"
   },
+  "vscode": {
+   "interpreter": {
+    "hash": "7b7fbdd20bcc2083504065e64dd68e11295ac29c39a09e225403f090756a3e6a"
+   }
+  }
  },
  "nbformat": 4,
167 changes: 167 additions & 0 deletions Lukes_folder/compute_posterior_matrix.py
@@ -0,0 +1,167 @@
import argparse
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
import json
import mrcfile
from itertools import islice

from cryo_sbi.inference.models import build_models
from cryo_sbi.inference import priors
from cryo_sbi.inference.priors import get_image_priors, PriorLoader
from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator

def main():
    print("hey!")
    device = "cuda"  # Device for computations
    save_figures = False

    #snr_values = [0.001, 0.01, 0.1, 1.0]
    snr_values = ["uniform"]
    #num_batch_values = [1, 10, 100]
    num_batch_values = [100]
    for i in range(len(snr_values)):
        for j in range(len(num_batch_values)):
            snr_val = snr_values[i]

            # Load and modify image config file
            image_config_file = "image_params_cryoer_cryobife_high_snr.json"
            image_config = json.load(open(image_config_file))

            if snr_val == "uniform":
                image_config["SNR"] = [0.001, 1.0]
            else:
                image_config["SNR"] = [snr_val, snr_val]
            snr_check = image_config["SNR"]
            print(f"SNR range: {snr_check}")

            # Load Neural Posterior Estimator
            train_config = json.load(open("Lars_hsp90/resnet18_encoder.json"))
            estimator = build_models.build_npe_flow_model(train_config)
            estimator.load_state_dict(torch.load("Lars_hsp90/hsp90_posterior.estimator"))
            estimator.cuda()
            estimator.eval()

            num_batches = num_batch_values[j]
            simulation_batch_size = 1024
            print(f"num batches: {num_batches}")

            log_posterior_mat, full_indices = batch_simulate_log_posterior(
                image_config,
                estimator,
                num_batches,
                device="cuda",
                simulation_batch_size=simulation_batch_size,
            )

            if snr_val == "uniform":
                fname = f"posterior_matrices/posterior_sampling_snr_uniform_num_batches_idx_{j}"
            else:
                fname = f"posterior_matrices/posterior_sampling_snr_idx_{i}_num_batches_idx_{j}"
            #np.savez(f"{fname}.npz", log_posterior_mat=log_posterior_mat, indices=full_indices)
            #np.savetxt(f"{fname}.txt", log_posterior_mat)
            plt.hist(full_indices, bins=20, density=True)
            plt.xlabel(r"$\theta$")
            plt.savefig("posterior_matrices/index_distribution.png", dpi=300)

def batch_simulate_log_posterior(
    image_config: dict,
    estimator,
    num_batches: int,
    n_workers: int = 1,
    device: str = "cpu",
    simulation_batch_size: int = 1024,
) -> tuple:
    """Simulate batches of cryo-EM images and evaluate the estimator's posterior over model indices.

    Args:
        image_config (dict): image config dictionary loaded from the JSON config file.
        estimator: trained neural posterior estimator (flow-based).
        num_batches (int): number of simulation batches to evaluate.
        n_workers (int, optional): number of data-loading workers. Defaults to 1.
        device (str, optional): computation device. Defaults to "cpu".
        simulation_batch_size (int, optional): number of images per batch. Defaults to 1024.

    Returns:
        Log-posterior matrix of shape (num_batches * simulation_batch_size, number of model indices),
        and the model index sampled for each simulated image.
    """

    if image_config["MODEL_FILE"].endswith("npy"):

        # hsp90 gets special treatment
        if "hsp90" in image_config["MODEL_FILE"]:
            models = (
                torch.from_numpy(
                    np.load(image_config["MODEL_FILE"])[:, 0, :, :],
                ).to(device).to(torch.float32)
            )
        else:
            models = (
                torch.from_numpy(
                    np.load(image_config["MODEL_FILE"]),
                )
                .to(device)
                .to(torch.float32)
            )
    else:
        models = torch.load(
            image_config["MODEL_FILE"],
            map_location=device,
        ).to(torch.float32)

    # Load prior and relevant parameter values
    image_prior = get_image_priors(len(models) - 1, image_config, device="cpu")
    prior_loader = PriorLoader(
        image_prior,
        batch_size=simulation_batch_size,
        num_workers=n_workers,
    )
    num_pixels = torch.tensor(
        image_config["N_PIXELS"], dtype=torch.float32, device=device
    )
    pixel_size = torch.tensor(
        image_config["PIXEL_SIZE"], dtype=torch.float32, device=device
    )

    # Initialize posterior matrix
    num_sim = int(num_batches * simulation_batch_size)
    sampling_indices = (
        torch.arange(0, models.shape[0] + 1, 1, dtype=torch.float32)
        .reshape(-1, 1)
        .to(device)
    )
    norm_indices = estimator.standardize(sampling_indices.to(device))
    posterior_mat = np.zeros((num_sim, sampling_indices.shape[0]))
    full_indices = np.zeros(num_sim)

    j = 0
    for parameters in islice(prior_loader, num_batches):

        # Sample a batch of parameters, simulate a batch of images
        indices, quaternions, sigma, shift, defocus, b_factor, amp, snr = parameters
        images = cryo_em_simulator(
            models,
            indices.to(device, non_blocking=True),
            quaternions.to(device, non_blocking=True),
            sigma.to(device, non_blocking=True),
            shift.to(device, non_blocking=True),
            defocus.to(device, non_blocking=True),
            b_factor.to(device, non_blocking=True),
            amp.to(device, non_blocking=True),
            snr.to(device, non_blocking=True),
            num_pixels,
            pixel_size,
        )
        batch_size = images.shape[0]

        # Evaluate the posterior of the flow conditioned on the image batch
        with torch.no_grad():

            # Get posterior function for each image in batch
            flow_at_images = estimator.flow(images.cuda(non_blocking=True))

            # Evaluate posterior function for each image in batch, at each model index
            indices_for_flow = norm_indices.repeat(1, images.shape[0]).unsqueeze(dim=2)  # each flow needs its own set of parameters
            logprobs = flow_at_images.log_prob(indices_for_flow).to(device)
            probs = torch.exp(logprobs) / 10  # normalize so probs sum to 1 (change of variables from standardizing indices to [0, 1])

        # Store posterior values and the model indices sampled
        posterior_mat[j:j + batch_size, :] = probs.T.detach().cpu().numpy()
        full_indices[j:j + batch_size] = indices.detach().cpu().numpy().flatten()
        j += batch_size
    log_posterior_mat = np.log(posterior_mat)
    return log_posterior_mat, full_indices

if __name__ == "__main__":
    main()
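
For reference, a minimal sketch (not part of the commit) of how the saved posterior matrix could be loaded for later analysis. It assumes the commented-out np.savez call in main() is re-enabled; the filename corresponds to the snr_values = ["uniform"] and num_batch_values = [100] settings above.

import numpy as np

# Assumes the np.savez line in main() has been uncommented so the matrix is written to disk.
data = np.load("posterior_matrices/posterior_sampling_snr_uniform_num_batches_idx_0.npz")
log_posterior_mat = data["log_posterior_mat"]  # shape: (num_batches * simulation_batch_size, number of model indices)
indices = data["indices"]                      # model index sampled for each simulated image
print(log_posterior_mat.shape, indices.shape)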