Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v3.4.2: AMP for ab-initio reconstruction; faster pose parsing #419

Merged
merged 9 commits into from
Nov 4, 2024
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ For any feedback, questions, or bugs, please file a Github issue or start a Gith

### New in Version 3.4.x
* [NEW] `cryodrgn plot_classes` for analysis visualizations colored by a given set of class labels
* support for RELION 3.1 .star files with separate optics tables
* support for np.float16 number formats used in RELION .mrcs outputs
* implemented [automatic mixed-precision training](https://pytorch.org/docs/stable/amp.html)
for ab-initio reconstruction, giving a 2-4x speedup
* support for RELION 3.1 .star files with separate optics tables, and for the np.float16 number formats used in RELION .mrcs outputs
* `cryodrgn backproject_voxel` produces cryoSPARC-style FSC curve plots with phase-randomization correction of
automatically generated tight masks
* `cryodrgn downsample` can create a new .star or .txt image stack from the corresponding stack format instead of
Expand Down
216 changes: 147 additions & 69 deletions cryodrgn/commands/abinit_het.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import pickle
import sys
import contextlib
import logging
from datetime import datetime as dt
import numpy as np
Expand All @@ -32,6 +33,11 @@
from cryodrgn.models import HetOnlyVAE, unparallelize
from cryodrgn.pose_search import PoseSearch

try:
import apex.amp as amp # type: ignore # PYR01
except ImportError:
pass

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -223,6 +229,12 @@ def add_args(parser):
type=int,
help="If set, reset the optimizer every N epochs",
)
group.add_argument(
"--no-amp",
action="store_false",
dest="amp",
help="Do not use mixed-precision training for accelerating training",
)
group.add_argument(
"--multigpu",
action="store_true",
Expand Down Expand Up @@ -451,6 +463,8 @@ def train(
enc_only=False,
poses=None,
ctf_params=None,
use_amp=False,
scaler=None,
):
y, yt = minibatch
use_tilt = yt is not None
Expand All @@ -470,87 +484,104 @@ def train(
# TODO: Center image?
# We do this in pose-supervised train_vae

# VAE inference of z
model.train()
optim.zero_grad()
input_ = (y, yt) if use_tilt else (y,)
if ctf_i is not None:
input_ = (x * ctf_i.sign() for x in input_) # phase flip by the ctf

_model = unparallelize(model)
assert isinstance(_model, HetOnlyVAE)
z_mu, z_logvar = _model.encode(*input_)
z = _model.reparameterize(z_mu, z_logvar)

lamb = eq_loss = None
if equivariance is not None:
lamb, equivariance_loss = equivariance
eq_loss = equivariance_loss(y, z_mu)

# pose inference
if poses is not None: # use provided poses
rot = poses[0]
trans = poses[1]
else: # pose search
model.eval()
with torch.no_grad():
rot, trans, _base_pose = ps.opt_theta_trans(
y,
z=z,
images_tilt=None if enc_only else yt,
ctf_i=ctf_i,
)
model.train()

# reconstruct circle of pixels instead of whole image
mask = lattice.get_circular_mask(L)
if scaler is not None:
amp_mode = torch.cuda.amp.autocast_mode.autocast()
else:
amp_mode = contextlib.nullcontext()

def gen_slice(R):
slice_ = model(lattice.coords[mask] @ R, z).view(B, -1)
with amp_mode:
# VAE inference of z
model.train()
optim.zero_grad()
input_ = (y, yt) if use_tilt else (y,)
if ctf_i is not None:
slice_ *= ctf_i.view(B, -1)[:, mask]
return slice_
input_ = (x * ctf_i.sign() for x in input_) # phase flip by the ctf

def translate(img):
img = lattice.translate_ht(img, trans.unsqueeze(1), mask)
return img.view(B, -1)
_model = unparallelize(model)
assert isinstance(_model, HetOnlyVAE)
z_mu, z_logvar = _model.encode(*input_)
z = _model.reparameterize(z_mu, z_logvar)

lamb = eq_loss = None
if equivariance is not None:
lamb, equivariance_loss = equivariance
eq_loss = equivariance_loss(y, z_mu)

# pose inference
if poses is not None: # use provided poses
rot = poses[0]
trans = poses[1]
else: # pose search
model.eval()
with torch.no_grad():
rot, trans, _base_pose = ps.opt_theta_trans(
y,
z=z,
images_tilt=None if enc_only else yt,
ctf_i=ctf_i,
)
model.train()

y = y.view(B, -1)[:, mask]
if use_tilt:
yt = yt.view(B, -1)[:, mask]
y = translate(y)
if use_tilt:
yt = translate(yt)
# reconstruct circle of pixels instead of whole image
mask = lattice.get_circular_mask(L)

if use_tilt:
gen_loss = 0.5 * F.mse_loss(gen_slice(rot), y) + 0.5 * F.mse_loss(
gen_slice(bnb.tilt @ rot), yt # type: ignore # noqa: F821
)
else:
gen_loss = F.mse_loss(gen_slice(rot), y)
def gen_slice(R):
slice_ = model(lattice.coords[mask] @ R, z).view(B, -1)
if ctf_i is not None:
slice_ *= ctf_i.view(B, -1)[:, mask]
return slice_

# latent loss
kld = torch.mean(
-0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
)
if torch.isnan(kld):
logger.info(z_mu[0])
logger.info(z_logvar[0])
raise RuntimeError("KLD is nan")
def translate(img):
img = lattice.translate_ht(img, trans.unsqueeze(1), mask)
return img.view(B, -1)

if beta_control is None:
loss = gen_loss + beta * kld / mask.sum().float()
else:
loss = gen_loss + beta_control * (beta - kld) ** 2 / mask.sum().float()
y = y.view(B, -1)[:, mask]
if use_tilt:
yt = yt.view(B, -1)[:, mask]
y = translate(y)
if use_tilt:
yt = translate(yt)

if loss is not None and eq_loss is not None:
loss += lamb * eq_loss
if use_tilt:
gen_loss = 0.5 * F.mse_loss(gen_slice(rot), y) + 0.5 * F.mse_loss(
gen_slice(bnb.tilt @ rot), yt # type: ignore # noqa: F821
)
else:
gen_loss = F.mse_loss(gen_slice(rot), y)

loss.backward()
# latent loss
kld = torch.mean(
-0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
)
if torch.isnan(kld):
logger.info(z_mu[0])
logger.info(z_logvar[0])
raise RuntimeError("KLD is nan")

if beta_control is None:
loss = gen_loss + beta * kld / mask.sum().float()
else:
loss = gen_loss + beta_control * (beta - kld) ** 2 / mask.sum().float()

if loss is not None and eq_loss is not None:
loss += lamb * eq_loss

if use_amp:
if scaler is not None:
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
else: # apex.amp mixed precision
with amp.scale_loss(loss, optim) as scaled_loss:
scaled_loss.backward()
optim.step()
else:
loss.backward()
optim.step()

optim.step()
save_pose = [rot.detach().cpu().numpy()]
save_pose.append(trans.detach().cpu().numpy())

return (
gen_loss.item(),
kld.item(),
Expand Down Expand Up @@ -833,6 +864,51 @@ def main(args):

optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

# Mixed precision training
scaler = None
if args.amp:
if args.batch_size % 8 != 0:
logger.warning(
f"Batch size {args.batch_size} not divisible by 8 "
f"and thus not optimal for AMP training!"
)
if (D - 1) % 8 != 0:
logger.warning(
f"Image size {D - 1} not divisible by 8 "
f"and thus not optimal for AMP training!"
)

if args.pdim % 8 != 0:
logger.warning(
f"Decoder hidden layer dimension {args.pdim} not divisible by 8 "
f"and thus not optimal for AMP training!"
)

# also check e.g. enc_mask dim?
if args.qdim % 8 != 0:
logger.warning(
f"Decoder hidden layer dimension {args.qdim} not divisible by 8 "
f"and thus not optimal for AMP training!"
)

if args.zdim % 8 != 0:
logger.warning(
f"Z dimension {args.zdim} is not a multiple of 8 "
"-- AMP training speedup is not optimized!"
)
if in_dim % 8 != 0:
logger.warning(
f"Masked input image dimension {in_dim} is not a mutiple of 8 "
"-- AMP training speedup is not optimized!"
)

# mixed precision with apex.amp
try:
model, optim = amp.initialize(model, optim, opt_level="O1")
# mixed precision with pytorch (v1.6+)
except: # noqa: E722
scaler = torch.cuda.amp.grad_scaler.GradScaler()

if args.load == "latest":
args = get_latest(args)

Expand Down Expand Up @@ -1007,6 +1083,8 @@ def main(args):
enc_only=args.enc_only,
poses=p,
ctf_params=ctf_i,
use_amp=args.amp,
scaler=scaler,
)
# logging
poses.append((ind.cpu().numpy(), pose))
Expand Down
Loading