added support for simple tensorboard logger
Dingel321 committed Jul 11, 2024
1 parent c6cbf99 commit c58f616
Showing 3 changed files with 65 additions and 8 deletions.
8 changes: 8 additions & 0 deletions src/cryo_sbi/inference/command_line_tools.py
@@ -37,6 +37,12 @@ def cl_npe_train_no_saving():
    cl_parser.add_argument(
        "--saving_freq", action="store", type=int, required=False, default=20
    )
    cl_parser.add_argument(
        "--val_set", action="store", type=str, required=False, default=None
    )
    cl_parser.add_argument(
        "--val_freq", action="store", type=int, required=False, default=10
    )
    cl_parser.add_argument(
        "--simulation_batch_size",
        action="store",
@@ -59,4 +65,6 @@ def cl_npe_train_no_saving():
        device=args.train_device,
        saving_frequency=args.saving_freq,
        simulation_batch_size=args.simulation_batch_size,
        validation_set=args.val_set,
        validation_frequency=args.val_freq,
    )
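The new --val_set flag takes a path to a torch-serialized dictionary holding validation images and their ground-truth indices (see the asserts in train_npe_model.py below). A minimal sketch of how such a file could be assembled — the tensor shapes, index values, and filename here are illustrative assumptions, not part of this commit:

import torch

# Hypothetical validation set: 10 images of 128x128 pixels and their
# ground-truth model indices, in the dict layout the trainer asserts on.
validation_set = {
    "IMAGES": torch.randn(10, 128, 128),    # assumed image shape
    "INDICES": torch.arange(10).float(),    # assumed ground-truth indices
}
torch.save(validation_set, "validation_set.pt")  # pass this path via --val_set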
64 changes: 56 additions & 8 deletions src/cryo_sbi/inference/train_npe_model.py
@@ -3,21 +3,20 @@
import torch
import numpy as np
import torch.optim as optim
from torch.utils.data import TensorDataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from lampe.data import JointLoader, H5Dataset
from lampe.inference import NPELoss
from lampe.utils import GDStep
from itertools import islice
import matplotlib.pyplot as plt

from cryo_sbi.inference.priors import get_image_priors, PriorLoader
from cryo_sbi.inference.models.build_models import build_npe_flow_model
from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator
from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
from cryo_sbi.inference.validate_train_config import check_train_params
import cryo_sbi.utils.image_utils as img_utils
from cryo_sbi.utils.estimator_utils import sample_posterior, evaluate_log_prob


def load_model(
@@ -56,6 +55,8 @@ def npe_train_no_saving(
device: str = "cpu",
saving_frequency: int = 20,
simulation_batch_size: int = 1024,
validation_set: Union[str, None] = None,
validation_frequency: int = 10
) -> None:
"""
Train NPE model by simulating training data on the fly.
@@ -84,9 +85,12 @@
    train_config = json.load(open(train_config))
    check_train_params(train_config)
    image_config = json.load(open(image_config))
    check_image_params(image_config)

    assert simulation_batch_size >= train_config["BATCH_SIZE"]
    assert simulation_batch_size % train_config["BATCH_SIZE"] == 0
    steps_per_epoch = simulation_batch_size // train_config["BATCH_SIZE"]
    epoch_repeats = 100  # number of times to simulate a batch of images per epoch

if image_config["MODEL_FILE"].endswith("npy"):
models = (
@@ -104,6 +108,7 @@

    image_prior = get_image_priors(len(models) - 1, image_config, device="cpu")
    index_to_cv = image_prior.priors[0].index_to_cv.to(device)
    max_index = index_to_cv.max().cpu()
    prior_loader = PriorLoader(
        image_prior, batch_size=simulation_batch_size, num_workers=n_workers
    )
@@ -120,18 +125,35 @@
    )

    loss = NPELoss(estimator)
    optimizer = optim.AdamW(
        estimator.parameters(),
        lr=train_config["LEARNING_RATE"],
        weight_decay=train_config["WEIGHT_DECAY"],
    )
    step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"])
    mean_loss = []

    if validation_set is not None:
        validation_set = torch.load(validation_set)
        assert isinstance(validation_set, dict), "Validation set must be a dictionary"
        assert "IMAGES" in validation_set, "Validation set must contain images"
        assert "INDICES" in validation_set, "Validation set must contain ground truth indices"

print("Initializing tensorboard writer")
writer = SummaryWriter()

if validation_set is not None:
num_validation_images = validation_set["IMAGES"].shape[0]
for i in range(num_validation_images):
fig, axes = plt.subplots(1, 1, figsize=(5, 5))
axes.imshow(validation_set["IMAGES"][i].cpu().numpy(), cmap="gray", vmax=1.5, vmin=-1.5)
axes.axis("off")
writer.add_figure(f"Validation/images", fig, global_step=i)
plt.close(fig)
writer.flush()

print("Training neural netowrk:")
estimator.train()
with tqdm(range(epochs), unit="epoch") as tq:
for epoch in tq:
losses = []
for parameters in islice(prior_loader, 100):
for parameters in islice(prior_loader, epoch_repeats):
                (
                    indices,
                    quaternions,
@@ -171,8 +193,34 @@

            tq.set_postfix(loss=losses.mean().item())
            mean_loss.append(losses.mean().item())
            current_step = (epoch + 1) * steps_per_epoch * epoch_repeats

            writer.add_scalar("Loss/mean", losses.mean().item(), current_step)
            writer.add_scalar("Loss/std", losses.std().item(), current_step)
            writer.add_scalar("Loss/last", losses[-1].item(), current_step)

            if epoch % saving_frequency == 0:
                torch.save(estimator.state_dict(), estimator_file + f"_epoch={epoch}")
            if validation_set is not None and epoch % validation_frequency == 0:
                estimator.eval()
                with torch.no_grad():
                    val_posterior_samples = sample_posterior(
                        estimator,
                        validation_set["IMAGES"],
                        num_samples=5000,
                        device=device,
                        batch_size=train_config["BATCH_SIZE"],
                    )
                    for i in range(num_validation_images):
                        writer.add_histogram(
                            f"Validation/posterior_{i}_index={validation_set['INDICES'][i].item()}",
                            val_posterior_samples[:, i],
                            global_step=current_step,
                        )
                estimator.train()

    writer.add_hparams(
        train_config,
        {"hparam/best_loss": min(mean_loss), "hparam/last_loss": mean_loss[-1]},
    )
    writer.flush()
    writer.close()
    torch.save(estimator.state_dict(), estimator_file)
    torch.save(torch.tensor(mean_loss), loss_file)
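For orientation, the global step at which the Loss/* scalars and validation histograms are logged advances by steps_per_epoch * epoch_repeats gradient updates per epoch. A quick sketch of that bookkeeping, assuming the CLI default simulation_batch_size of 1024 and a BATCH_SIZE of 256 in the train config (an assumed value, not from this commit):

simulation_batch_size = 1024  # CLI default in command_line_tools.py
batch_size = 256              # assumed train_config["BATCH_SIZE"]
epoch_repeats = 100           # hard-coded in the training loop above

steps_per_epoch = simulation_batch_size // batch_size  # 4 gradient steps per simulated batch
for epoch in range(3):
    print((epoch + 1) * steps_per_epoch * epoch_repeats)  # 400, 800, 1200

Since SummaryWriter() is constructed without arguments, the event files land in the default ./runs/ directory, which is where TensorBoard should be pointed.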
1 change: 1 addition & 0 deletions src/cryo_sbi/inference/validate_train_config.py
@@ -21,6 +21,7 @@ def check_train_params(config: dict) -> None:
"BATCH_SIZE",
"THETA_SHIFT",
"THETA_SCALE",
"WEIGHT_DECAY"
]

    for key in needed_keys:
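Because WEIGHT_DECAY is now a required key, existing training configs must gain it or check_train_params will fail. A sketch of the relevant fragment — the other values are illustrative, and a complete config needs every key in needed_keys:

import json

# Only a fragment; a real config must contain all required keys.
train_config_fragment = {
    "LEARNING_RATE": 0.0003,  # illustrative value
    "BATCH_SIZE": 256,        # illustrative value
    "WEIGHT_DECAY": 0.001,    # new required key; the value previously hard-coded in train_npe_model.py
}
with open("train_config.json", "w") as f:
    json.dump(train_config_fragment, f, indent=4)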
