
Commit

fixing issues with notebook generation for landscape analysis and adding relevant unit tests
michal-g committed Oct 11, 2024
1 parent b51a3b2 commit 0f9c25f
Showing 4 changed files with 116 additions and 56 deletions.
12 changes: 3 additions & 9 deletions cryodrgn/commands/analyze.py
@@ -28,7 +28,7 @@
logger = logging.getLogger(__name__)


def add_args(parser):
def add_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"workdir", type=os.path.abspath, help="Directory with cryoDRGN results"
)
@@ -402,7 +402,8 @@ def gen_volumes(self, outdir, z_values):
analysis.gen_volumes(self.weights, self.config, zfile, outdir, **self.vol_args)


def main(args):
def main(args: argparse.Namespace) -> None:
matplotlib.use("Agg") # non-interactive backend
t1 = dt.now()
E = args.epoch
workdir = args.workdir
@@ -527,10 +528,3 @@ def main(args):
nbformat.write(filter_ntbook, f)

logger.info(f"Finished in {dt.now() - t1}")


if __name__ == "__main__":
matplotlib.use("Agg") # non-interactive backend
parser = argparse.ArgumentParser(description=__doc__)
add_args(parser)
main(parser.parse_args())
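
With the `if __name__ == "__main__":` block gone and the `matplotlib.use("Agg")` call moved inside `main()`, the command behaves the same whether it is run through the cryodrgn CLI or called programmatically, which is how the new tests below drive it. A minimal sketch of that calling pattern (the workdir path and epoch value are placeholders):

import argparse

from cryodrgn.commands import analyze

# Build the parser, register the command's arguments, and hand the parsed
# namespace to main(); main() now selects the non-interactive matplotlib
# backend itself, so no extra setup is needed.
parser = argparse.ArgumentParser()
analyze.add_args(parser)
analyze.main(parser.parse_args(["/path/to/workdir", "24"]))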
58 changes: 48 additions & 10 deletions cryodrgn/commands/analyze_landscape_full.py
@@ -11,19 +11,20 @@
"""
import argparse
import os
import os.path
import pprint
import shutil
from datetime import datetime as dt
import logging
import nbformat

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data.dataloader import default_collate
from torch.utils.data import Dataset, DataLoader

import cryodrgn
from cryodrgn import config, utils
from cryodrgn.models import HetOnlyVAE, ResidLinearMLP
@@ -63,6 +64,7 @@ def add_args(parser: argparse.ArgumentParser) -> None:
group = parser.add_argument_group("Volume generation arguments")
group.add_argument(
"-N",
"--training-volumes",
type=int,
default=10000,
help="Number of training volumes to generate (default: %(default)s)",
@@ -178,10 +180,12 @@ def generate_and_map_volumes(
zfile, cfg, weights, mask_mrc, pca_obj_pkl, landscape_dir, outdir, args
):
# Sample z
logger.info(f"Sampling {args.N} particles from {zfile}")
logger.info(f"Sampling {args.training_volumes} particles from {zfile}")
np.random.seed(args.seed)
z_all = utils.load_pkl(zfile)
ind = np.array(sorted(np.random.choice(len(z_all), args.N, replace=False))) # type: ignore
ind = np.array(
sorted(np.random.choice(len(z_all), args.training_volumes, replace=False))
) # type: ignore
z_sample = z_all[ind]
utils.save_pkl(z_sample, f"{outdir}/z.sampled.pkl")
utils.save_pkl(ind, f"{outdir}/ind.sampled.pkl")
@@ -223,7 +227,7 @@ def generate_and_map_volumes(
t1 = dt.now()
embeddings = []
for i, zz in enumerate(z):
if i % 100 == 0:
if i % 1 == 0:
logger.info(i)

if args.downsample:
@@ -250,7 +254,7 @@
embeddings = np.array(embeddings).reshape(len(z), -1).astype(np.float32)

td = dt.now() - t1
logger.info(f"Finished generating {args.N} volumes in {td}")
logger.info(f"Finished generating {args.training_volumes} volumes in {td}")

return z, embeddings

@@ -286,7 +290,7 @@ def train_model(x, y, outdir, zfile, args):
device
)
logger.info(model)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Train
for epoch in range(1, args.epochs + 1):
@@ -312,6 +316,7 @@
def main(args: argparse.Namespace) -> None:
t1 = dt.now()
logger.info(args)

E = args.epoch
workdir = args.workdir
zfile = f"{workdir}/z.{E}.pkl"
@@ -335,6 +340,17 @@ def main(args: argparse.Namespace) -> None:
pca_obj_pkl
), f"{pca_obj_pkl} missing. Did you run cryodrgn analyze_landscape?"

kmeans_folder = [
p for p in os.listdir(landscape_dir) if p.startswith("clustering_L2_")
]
if len(kmeans_folder) == 0:
raise RuntimeError(
"No clustering folders `clustering_L2_` found. "
"Did you run cryodrgn analyze_landscape?"
)
kmeans_folder = kmeans_folder[0]
link_method, kmeans_K = kmeans_folder.split("_")[-2:]

logger.info(f"Saving results to {outdir}")
if not os.path.exists(outdir):
os.mkdir(outdir)
@@ -361,13 +377,35 @@ def main(args: argparse.Namespace) -> None:
utils.save_pkl(embeddings_all, f"{outdir}/vol_pca_all.pkl")

# Copy viz notebook
out_ipynb = f"{landscape_dir}/cryoDRGN_analyze_landscape.ipynb"
out_ipynb = os.path.join(landscape_dir, "cryoDRGN_analyze_landscape.ipynb")
if not os.path.exists(out_ipynb):
logger.info("Creating jupyter notebook...")
ipynb = f"{cryodrgn._ROOT}/templates/cryoDRGN_analyze_landscape_template.ipynb"
ipynb = os.path.join(
cryodrgn._ROOT, "templates", "cryoDRGN_analyze_landscape_template.ipynb"
)
shutil.copyfile(ipynb, out_ipynb)
else:
logger.info(f"{out_ipynb} already exists. Skipping")

# Lazily look at the beginning of the notebook for the epoch number to update
with open(out_ipynb, "r") as f:
filter_ntbook = nbformat.read(f, as_version=nbformat.NO_CONVERT)

for cell in filter_ntbook["cells"]:
cell["source"] = cell["source"].replace("EPOCH = None", f"EPOCH = {args.epoch}")
cell["source"] = cell["source"].replace(
"WORKDIR = None", f'WORKDIR = "{args.workdir}"'
)
cell["source"] = cell["source"].replace(
"K = None", f"K = {args.training_volumes}"
)
cell["source"] = cell["source"].replace("M = None", f"M = {kmeans_K}")
cell["source"] = cell["source"].replace(
"linkage = None", f'linkage = "{link_method}"'
)

with open(out_ipynb, "w") as f:
nbformat.write(filter_ntbook, f)

logger.info(out_ipynb)
logger.info(f"Finished in {dt.now()-t1}")
21 changes: 10 additions & 11 deletions cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb
@@ -15,18 +15,17 @@
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import pickle\n",
"import subprocess\n",
"import os, sys\n",
"\n",
"from cryodrgn import mrc\n",
"from cryodrgn.mrcfile import parse_mrc\n",
"from cryodrgn import analysis\n",
"from cryodrgn import utils\n",
"from cryodrgn import dataset\n",
"from cryodrgn import ctf\n",
" \n",
" \n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import plotly.graph_objs as go\n",
@@ -54,12 +53,12 @@
"metadata": {},
"outputs": [],
"source": [
"EPOCH = None # change me if necessary!\n",
"WORKDIR = '..' # Directory with cryoDRGN outputs\n",
"EPOCH = None # change me if necessary!\n",
"WORKDIR = None # Directory with cryoDRGN outputs\n",
"\n",
"K = 1000 # Number of sketched volumes\n",
"M = 10 # Number of clusters\n",
"linkage = 'average'"
"K = None # Number of sketched volumes\n",
"M = None # Number of clusters\n",
"linkage = None # Linkage method used for clustering"
]
},
{
@@ -101,7 +100,7 @@
"metadata": {},
"outputs": [],
"source": [
"mask = mrc.parse_mrc(f'{landscape_dir}/mask.mrc')\n",
"mask = parse_mrc(f'{landscape_dir}/mask.mrc')\n",
"mask = mask[0].astype(bool)\n",
"print(f'{mask.sum()} out of {np.prod(mask.shape)} voxels included in mask')"
]
@@ -143,8 +142,8 @@
"source": [
"# Load volumes\n",
"'''\n",
"volm, _ = mrc.parse_mrc(f'kmeans{K}/vol_mean.mrc')\n",
"vols = np.array([mrc.parse_mrc(f'kmeans{K}/vol_{i:03d}.mrc')[0][mask] for i in range(K)])\n",
"volm, _ = parse_mrc(f'kmeans{K}/vol_mean.mrc')\n",
"vols = np.array([parse_mrc(f'kmeans{K}/vol_{i:03d}.mrc')[0][mask] for i in range(K)])\n",
"vols.shape\n",
"vols[vols<0]=0\n",
"'''"
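
The template now imports `parse_mrc` from `cryodrgn.mrcfile` in place of the old `cryodrgn.mrc` module. As the cells above suggest, `parse_mrc` returns a `(data, header)` pair, which is why the mask cell indexes `[0]` before binarizing. A short sketch of that loading pattern, assuming a placeholder path:

import numpy as np

from cryodrgn.mrcfile import parse_mrc

# parse_mrc returns the voxel array together with the MRC header; only the
# array is needed for the boolean mask.
mask, _header = parse_mrc("landscape.24/mask.mrc")  # placeholder path
mask = mask.astype(bool)
print(f"{mask.sum()} out of {np.prod(mask.shape)} voxels included in mask")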
81 changes: 55 additions & 26 deletions tests/test_reconstruct.py
@@ -127,8 +127,8 @@ def test_analyze(self, tmpdir_factory, particles, poses, ctf, indices, epoch):
outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf)
parser = argparse.ArgumentParser()
analyze.add_args(parser)
args = parser.parse_args([outdir, str(epoch)])
analyze.main(args)
analyze.main(parser.parse_args([outdir, str(epoch)]))

assert os.path.exists(os.path.join(outdir, f"analyze.{epoch}"))

@pytest.mark.parametrize(
@@ -251,6 +251,26 @@ def test_landscape_full(
analyze_landscape_full.add_args(parser)
analyze_landscape_full.main(parser.parse_args(args))

@pytest.mark.parametrize("ctf", ["CTF-Test"], indirect=True)
def test_landscape_notebook(self, tmpdir_factory, particles, poses, ctf, indices):
"""Execute the demo Jupyter notebooks produced by landscape analysis."""
outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf)
orig_cwd = os.path.abspath(os.getcwd())
os.chdir(os.path.join(outdir, "landscape.3"))
notebook_fl = "cryoDRGN_analyze_landscape.ipynb"
assert os.path.exists(notebook_fl)

with open(notebook_fl) as ff:
nb_in = nbformat.read(ff, nbformat.NO_CONVERT)

try:
ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in)
except CellExecutionError as e:
os.chdir(orig_cwd)
raise e

os.chdir(orig_cwd)

@pytest.mark.parametrize(
"ctf, seed, steps, points",
[
@@ -479,19 +499,24 @@ def test_train_model(self, tmpdir_factory, particles, ctf, indices):
def test_analyze(self, tmpdir_factory, particles, ctf, indices):
"""Produce standard analyses for a particular epoch."""
outdir = self.get_outdir(tmpdir_factory, particles, indices, ctf)
args = analyze.add_args(argparse.ArgumentParser()).parse_args(
[
outdir,
"1", # Epoch number to analyze - 0-indexed
"--pc",
"3", # Number of principal component traversals to generate
"--ksample",
"10", # Number of kmeans samples to generate
"--vol-start-index",
"1",
]

parser = argparse.ArgumentParser()
analyze.add_args(parser)
analyze.main(
parser.parse_args(
[
outdir,
"1", # Epoch number to analyze - 0-indexed
"--pc",
"3", # Number of principal component traversals to generate
"--ksample",
"10", # Number of kmeans samples to generate
"--vol-start-index",
"1",
]
)
)
analyze.main(args)

assert os.path.exists(os.path.join(outdir, "analyze.1"))

@pytest.mark.parametrize("nb_lbl", ["cryoDRGN_figures"])
@@ -751,19 +776,23 @@ def test_analyze(self, tmpdir_factory, particles, indices, poses, ctf, datadir):
outdir = self.get_outdir(
tmpdir_factory, particles, poses, ctf, indices, datadir
)
args = analyze.add_args(argparse.ArgumentParser()).parse_args(
[
outdir,
"4", # Epoch number to analyze - 0-indexed
"--pc",
"3", # Number of principal component traversals to generate
"--ksample",
"2", # Number of kmeans samples to generate
"--vol-start-index",
"1",
]

parser = argparse.ArgumentParser()
analyze.add_args(parser)
analyze.main(
parser.parse_args(
[
outdir,
"4", # Epoch number to analyze - 0-indexed
"--pc",
"3", # Number of principal component traversals to generate
"--ksample",
"2", # Number of kmeans samples to generate
"--vol-start-index",
"1",
]
)
)
analyze.main(args)
assert os.path.exists(os.path.join(outdir, "analyze.4"))

@pytest.mark.parametrize(
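
The new `test_landscape_notebook` checks that the notebook written by `analyze_landscape_full` actually runs: it changes into the landscape output directory so the notebook's relative paths resolve, then executes every cell headlessly with nbconvert's `ExecutePreprocessor`. A condensed sketch of that pattern, using `try/finally` instead of the test's explicit re-raise to restore the working directory (directory and file names are placeholders):

import os

import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

def run_notebook(nb_dir, nb_file):
    """Execute a notebook from inside its own directory."""
    orig_cwd = os.getcwd()
    os.chdir(nb_dir)
    try:
        with open(nb_file) as f:
            nb = nbformat.read(f, as_version=nbformat.NO_CONVERT)
        # Runs every cell in order; raises CellExecutionError on the first
        # failing cell.
        ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb)
    finally:
        os.chdir(orig_cwd)

run_notebook("landscape.3", "cryoDRGN_analyze_landscape.ipynb")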
