From 0f9c25f4a25eef088ed5b54820272ff7adc77eab Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Fri, 11 Oct 2024 16:36:06 -0400 Subject: [PATCH] fixing issues with notebook generation for landscape analysis and adding relevant unit tests --- cryodrgn/commands/analyze.py | 12 +-- cryodrgn/commands/analyze_landscape_full.py | 58 ++++++++++--- .../cryoDRGN_analyze_landscape_template.ipynb | 21 +++-- tests/test_reconstruct.py | 81 +++++++++++++------ 4 files changed, 116 insertions(+), 56 deletions(-) diff --git a/cryodrgn/commands/analyze.py b/cryodrgn/commands/analyze.py index 4c987ea7..a4012d46 100644 --- a/cryodrgn/commands/analyze.py +++ b/cryodrgn/commands/analyze.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -def add_args(parser): +def add_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "workdir", type=os.path.abspath, help="Directory with cryoDRGN results" ) @@ -402,7 +402,8 @@ def gen_volumes(self, outdir, z_values): analysis.gen_volumes(self.weights, self.config, zfile, outdir, **self.vol_args) -def main(args): +def main(args: argparse.Namespace) -> None: + matplotlib.use("Agg") # non-interactive backend t1 = dt.now() E = args.epoch workdir = args.workdir @@ -527,10 +528,3 @@ def main(args): nbformat.write(filter_ntbook, f) logger.info(f"Finished in {dt.now() - t1}") - - -if __name__ == "__main__": - matplotlib.use("Agg") # non-interactive backend - parser = argparse.ArgumentParser(description=__doc__) - add_args(parser) - main(parser.parse_args()) diff --git a/cryodrgn/commands/analyze_landscape_full.py b/cryodrgn/commands/analyze_landscape_full.py index 3ed38dd4..b5eddf49 100644 --- a/cryodrgn/commands/analyze_landscape_full.py +++ b/cryodrgn/commands/analyze_landscape_full.py @@ -11,19 +11,20 @@ """ import argparse import os -import os.path import pprint import shutil from datetime import datetime as dt import logging +import nbformat + import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data.dataloader import default_collate -import torch.optim as optim from sklearn.model_selection import train_test_split +from torch.utils.data.dataloader import default_collate from torch.utils.data import Dataset, DataLoader + import cryodrgn from cryodrgn import config, utils from cryodrgn.models import HetOnlyVAE, ResidLinearMLP @@ -63,6 +64,7 @@ def add_args(parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group("Volume generation arguments") group.add_argument( "-N", + "--training-volumes", type=int, default=10000, help="Number of training volumes to generate (default: %(default)s)", @@ -178,10 +180,12 @@ def generate_and_map_volumes( zfile, cfg, weights, mask_mrc, pca_obj_pkl, landscape_dir, outdir, args ): # Sample z - logger.info(f"Sampling {args.N} particles from {zfile}") + logger.info(f"Sampling {args.training_volumes} particles from {zfile}") np.random.seed(args.seed) z_all = utils.load_pkl(zfile) - ind = np.array(sorted(np.random.choice(len(z_all), args.N, replace=False))) # type: ignore + ind = np.array( + sorted(np.random.choice(len(z_all), args.training_volumes, replace=False)) + ) # type: ignore z_sample = z_all[ind] utils.save_pkl(z_sample, f"{outdir}/z.sampled.pkl") utils.save_pkl(ind, f"{outdir}/ind.sampled.pkl") @@ -223,7 +227,7 @@ def generate_and_map_volumes( t1 = dt.now() embeddings = [] for i, zz in enumerate(z): - if i % 100 == 0: + if i % 1 == 0: logger.info(i) if args.downsample: @@ -250,7 +254,7 @@ def generate_and_map_volumes( embeddings = np.array(embeddings).reshape(len(z), -1).astype(np.float32) td = dt.now() - t1 - logger.info(f"Finished generating {args.N} volumes in {td}") + logger.info(f"Finished generating {args.training_volumes} volumes in {td}") return z, embeddings @@ -286,7 +290,7 @@ def train_model(x, y, outdir, zfile, args): device ) logger.info(model) - optimizer = optim.Adam(model.parameters(), lr=args.lr) + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # Train for epoch in range(1, args.epochs + 1): @@ -312,6 +316,7 @@ def train_model(x, y, outdir, zfile, args): def main(args: argparse.Namespace) -> None: t1 = dt.now() logger.info(args) + E = args.epoch workdir = args.workdir zfile = f"{workdir}/z.{E}.pkl" @@ -335,6 +340,17 @@ def main(args: argparse.Namespace) -> None: pca_obj_pkl ), f"{pca_obj_pkl} missing. Did you run cryodrgn analyze_landscape?" + kmeans_folder = [ + p for p in os.listdir(landscape_dir) if p.startswith("clustering_L2_") + ] + if len(kmeans_folder) == 0: + raise RuntimeError( + "No clustering folders `clustering_L2_` found. " + "Did you run cryodrgn analyze_landscape?" + ) + kmeans_folder = kmeans_folder[0] + link_method, kmeans_K = kmeans_folder.split("_")[-2:] + logger.info(f"Saving results to {outdir}") if not os.path.exists(outdir): os.mkdir(outdir) @@ -361,13 +377,35 @@ def main(args: argparse.Namespace) -> None: utils.save_pkl(embeddings_all, f"{outdir}/vol_pca_all.pkl") # Copy viz notebook - out_ipynb = f"{landscape_dir}/cryoDRGN_analyze_landscape.ipynb" + out_ipynb = os.path.join(landscape_dir, "cryoDRGN_analyze_landscape.ipynb") if not os.path.exists(out_ipynb): logger.info("Creating jupyter notebook...") - ipynb = f"{cryodrgn._ROOT}/templates/cryoDRGN_analyze_landscape_template.ipynb" + ipynb = os.path.join( + cryodrgn._ROOT, "templates", "cryoDRGN_analyze_landscape_template.ipynb" + ) shutil.copyfile(ipynb, out_ipynb) else: logger.info(f"{out_ipynb} already exists. Skipping") + # Lazily look at the beginning of the notebook for the epoch number to update + with open(out_ipynb, "r") as f: + filter_ntbook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + for cell in filter_ntbook["cells"]: + cell["source"] = cell["source"].replace("EPOCH = None", f"EPOCH = {args.epoch}") + cell["source"] = cell["source"].replace( + "WORKDIR = None", f'WORKDIR = "{args.workdir}"' + ) + cell["source"] = cell["source"].replace( + "K = None", f"K = {args.training_volumes}" + ) + cell["source"] = cell["source"].replace("M = None", f"M = {kmeans_K}") + cell["source"] = cell["source"].replace( + "linkage = None", f'linkage = "{link_method}"' + ) + + with open(out_ipynb, "w") as f: + nbformat.write(filter_ntbook, f) + logger.info(out_ipynb) logger.info(f"Finished in {dt.now()-t1}") diff --git a/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb b/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb index 95981184..c1586bfb 100644 --- a/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb @@ -15,18 +15,17 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", "import numpy as np\n", "import pickle\n", "import subprocess\n", "import os, sys\n", "\n", - "from cryodrgn import mrc\n", + "from cryodrgn.mrcfile import parse_mrc\n", "from cryodrgn import analysis\n", "from cryodrgn import utils\n", "from cryodrgn import dataset\n", "from cryodrgn import ctf\n", - " \n", + " \n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import plotly.graph_objs as go\n", @@ -54,12 +53,12 @@ "metadata": {}, "outputs": [], "source": [ - "EPOCH = None # change me if necessary!\n", - "WORKDIR = '..' # Directory with cryoDRGN outputs\n", + "EPOCH = None # change me if necessary!\n", + "WORKDIR = None # Directory with cryoDRGN outputs\n", "\n", - "K = 1000 # Number of sketched volumes\n", - "M = 10 # Number of clusters\n", - "linkage = 'average'" + "K = None # Number of sketched volumes\n", + "M = None # Number of clusters\n", + "linkage = None # Linkage method used for clustering" ] }, { @@ -101,7 +100,7 @@ "metadata": {}, "outputs": [], "source": [ - "mask = mrc.parse_mrc(f'{landscape_dir}/mask.mrc')\n", + "mask = parse_mrc(f'{landscape_dir}/mask.mrc')\n", "mask = mask[0].astype(bool)\n", "print(f'{mask.sum()} out of {np.prod(mask.shape)} voxels included in mask')" ] @@ -143,8 +142,8 @@ "source": [ "# Load volumes\n", "'''\n", - "volm, _ = mrc.parse_mrc(f'kmeans{K}/vol_mean.mrc')\n", - "vols = np.array([mrc.parse_mrc(f'kmeans{K}/vol_{i:03d}.mrc')[0][mask] for i in range(K)])\n", + "volm, _ = parse_mrc(f'kmeans{K}/vol_mean.mrc')\n", + "vols = np.array([parse_mrc(f'kmeans{K}/vol_{i:03d}.mrc')[0][mask] for i in range(K)])\n", "vols.shape\n", "vols[vols<0]=0\n", "'''" diff --git a/tests/test_reconstruct.py b/tests/test_reconstruct.py index 465654b5..f78d7f0b 100644 --- a/tests/test_reconstruct.py +++ b/tests/test_reconstruct.py @@ -127,8 +127,8 @@ def test_analyze(self, tmpdir_factory, particles, poses, ctf, indices, epoch): outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf) parser = argparse.ArgumentParser() analyze.add_args(parser) - args = parser.parse_args([outdir, str(epoch)]) - analyze.main(args) + analyze.main(parser.parse_args([outdir, str(epoch)])) + assert os.path.exists(os.path.join(outdir, f"analyze.{epoch}")) @pytest.mark.parametrize( @@ -251,6 +251,26 @@ def test_landscape_full( analyze_landscape_full.add_args(parser) analyze_landscape_full.main(parser.parse_args(args)) + @pytest.mark.parametrize("ctf", ["CTF-Test"], indirect=True) + def test_landscape_notebook(self, tmpdir_factory, particles, poses, ctf, indices): + """Execute the demo Jupyter notebooks produced by landscape analysis.""" + outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf) + orig_cwd = os.path.abspath(os.getcwd()) + os.chdir(os.path.join(outdir, "landscape.3")) + notebook_fl = "cryoDRGN_analyze_landscape.ipynb" + assert os.path.exists(notebook_fl) + + with open(notebook_fl) as ff: + nb_in = nbformat.read(ff, nbformat.NO_CONVERT) + + try: + ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in) + except CellExecutionError as e: + os.chdir(orig_cwd) + raise e + + os.chdir(orig_cwd) + @pytest.mark.parametrize( "ctf, seed, steps, points", [ @@ -479,19 +499,24 @@ def test_train_model(self, tmpdir_factory, particles, ctf, indices): def test_analyze(self, tmpdir_factory, particles, ctf, indices): """Produce standard analyses for a particular epoch.""" outdir = self.get_outdir(tmpdir_factory, particles, indices, ctf) - args = analyze.add_args(argparse.ArgumentParser()).parse_args( - [ - outdir, - "1", # Epoch number to analyze - 0-indexed - "--pc", - "3", # Number of principal component traversals to generate - "--ksample", - "10", # Number of kmeans samples to generate - "--vol-start-index", - "1", - ] + + parser = argparse.ArgumentParser() + analyze.add_args(parser) + analyze.main( + parser.parse_args( + [ + outdir, + "1", # Epoch number to analyze - 0-indexed + "--pc", + "3", # Number of principal component traversals to generate + "--ksample", + "10", # Number of kmeans samples to generate + "--vol-start-index", + "1", + ] + ) ) - analyze.main(args) + assert os.path.exists(os.path.join(outdir, "analyze.1")) @pytest.mark.parametrize("nb_lbl", ["cryoDRGN_figures"]) @@ -751,19 +776,23 @@ def test_analyze(self, tmpdir_factory, particles, indices, poses, ctf, datadir): outdir = self.get_outdir( tmpdir_factory, particles, poses, ctf, indices, datadir ) - args = analyze.add_args(argparse.ArgumentParser()).parse_args( - [ - outdir, - "4", # Epoch number to analyze - 0-indexed - "--pc", - "3", # Number of principal component traversals to generate - "--ksample", - "2", # Number of kmeans samples to generate - "--vol-start-index", - "1", - ] + + parser = argparse.ArgumentParser() + analyze.add_args(parser) + analyze.main( + parser.parse_args( + [ + outdir, + "4", # Epoch number to analyze - 0-indexed + "--pc", + "3", # Number of principal component traversals to generate + "--ksample", + "2", # Number of kmeans samples to generate + "--vol-start-index", + "1", + ] + ) ) - analyze.main(args) assert os.path.exists(os.path.join(outdir, "analyze.4")) @pytest.mark.parametrize(