Skip to content

Commit

Permalink
Merge pull request #372 from ml-struct-bio/v3.3.1
Browse files Browse the repository at this point in the history
v3.3.1: fixes to backprojection and tilt --ind; per tomo star filtering
  • Loading branch information
michal-g authored May 9, 2024
2 parents d30e559 + 1a63702 commit 9dd0915
Show file tree
Hide file tree
Showing 32 changed files with 1,097 additions and 547 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/beta_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Beta Release
on:
push:
tags:
- '[0-9]+.[0-9]+.[0-9]+*'
- '[0-9]+\.[0-9]+\.[0-9]+-*'

jobs:
beta-release:
Expand All @@ -21,10 +21,9 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: '3.9'

- name: Release to TestPyPI
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: '3.9'

- name: Checkout repository
uses: actions/checkout@v4
Expand Down
13 changes: 9 additions & 4 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ name: Release
on:
push:
branches: [ main ]
tags:
- '[0-9]+.[0-9]+.[0-9]+*'

jobs:
release:
Expand All @@ -22,10 +20,17 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: '3.9'

- name: Check Tag
id: check-tag
run: |
if [[ ${{ github.event.ref }} =~ ^refs/tags/[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "match=true" >> $GITHUB_OUTPUT
fi
- name: Release to pypi
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
if: steps.check-tag.outputs.match == 'true'
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_MAIN_TOKEN }}
Expand Down
36 changes: 36 additions & 0 deletions .github/workflows/style.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Code Linting

on:
push:
branches: [ main, develop ]
tags:
- '[0-9]+\.[0-9]+\.[0-9]+'
- '[0-9]+\.[0-9]+\.[0-9]+-*'
pull_request:
branches: [ main, develop ]

jobs:
run_tests:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install cryoDRGN with dev dependencies
run: |
python3 -m pip install .[dev]
- name: Run pre-commit checks
run: |
pre-commit run --all-files --show-diff-on-failure
- name: Run Pyright
run: |
pyright --version
#pyright
21 changes: 8 additions & 13 deletions .github/workflows/main.yml → .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
name: Continuous integration testing
name: CI Testing

on:
push:
branches: [ main, develop ]
tags:
- '[0-9]+\.[0-9]+\.[0-9]+'
- '[0-9]+\.[0-9]+\.[0-9]+-*'
pull_request:
branches: [ main, develop ]

Expand All @@ -24,19 +27,11 @@ jobs:
with:
python-version: ${{ matrix.python }}

- name: Install cryoDRGN with dev dependencies
- name: Install cryoDRGN with pytest dependencies
run: |
python3 -m pip install .[dev]
- name: Pre-commit checks
run: |
pre-commit run --all-files --show-diff-on-failure
- name: Pyright
run: |
pyright --version
#pyright
python3 -m pip install pytest-xdist
python3 -m pip install .
- name: Pytest
run: |
pytest -v
pytest -v -n auto --dist=loadscope
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
![pypi-downloads](https://img.shields.io/pypi/dm/cryodrgn?style=flat&label=PyPI%20Downloads&logo=pypi&labelColor=%23FFF8EC)
![pypi-downloads](https://img.shields.io/pypi/dm/cryodrgn?style=flat&label=PyPI%20Downloads&logo=pypi&logoColor=%233775A9&labelColor=%23FFF8EC)
![stable-release](https://img.shields.io/pypi/v/cryodrgn?style=flat&logo=pypi&logoColor=%233775A9&logoSize=auto&label=stable%20release&labelColor=%23FFF8EC)
![beta-release](https://img.shields.io/pypi/v/cryodrgn?pypiBaseUrl=https%3A%2F%2Ftest.pypi.org&style=flat&logo=pypi&logoColor=%233775A9&logoSize=auto&label=beta%20release&labelColor=%23FFF8EC)
![ci-test](https://github.com/ml-struct-bio/cryodrgn/actions/workflows/main.yml/badge.svg)
![grading](https://img.shields.io/codefactor/grade/github/ml-struct-bio/cryodrgn/main?style=flat&logo=codefactor&logoColor=%23F44A6A&logoSize=auto&label=CodeFactor%20Grade&labelColor=%23FFF8EC)
![ci-test](https://github.com/ml-struct-bio/cryodrgn/actions/workflows/tests.yml/badge.svg)


# :snowflake::dragon: cryoDRGN: Deep Reconstructing Generative Networks for cryo-EM and cryo-ET heterogeneous reconstruction
Expand Down Expand Up @@ -197,7 +198,7 @@ The official version 1.0 release. This version introduces several new tools for

You can alternatively install a newer, less stable, development version of `cryodrgn` using our beta release channel:

(cryodrgn) $ pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ "cryodrgn<=3.3.0" --pre
(cryodrgn) $ pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ cryodrgn --pre

More installation instructions are found in the [documentation](https://ez-lab.gitbook.io/cryodrgn/installation).

Expand Down
13 changes: 6 additions & 7 deletions cryodrgn/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,19 +496,18 @@ def ipy_plot_interactive(df, opacity=0.3):
text = [f"index {i}" for i in df.index] # hovertext

xaxis, yaxis = df.columns[0], df.columns[1]
plt_size = max(1.7, 53 / df.shape[0] ** 0.31)
plt_mrk = dict(
size=plt_size, opacity=opacity, color=np.arange(len(df)), colorscale="hsv"
)
f = go.FigureWidget(
[
go.Scattergl(
x=df[xaxis],
y=df[yaxis],
mode="markers",
text=text,
marker=dict(
size=2, opacity=opacity, color=np.arange(len(df)), colorscale="hsv"
),
x=df[xaxis], y=df[yaxis], mode="markers", text=text, marker=plt_mrk
)
]
)

scatter = f.data[0]
f.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
f.layout.dragmode = "lasso"
Expand Down
103 changes: 67 additions & 36 deletions cryodrgn/commands/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def analyze_zN(
plt.legend()
plt.tight_layout()
plt.savefig(f"{outdir}/learning_curve_epoch{epoch}.png")
plt.close()

def plt_pc_labels(x=0, y=1):
plt.xlabel(f"PC{x+1} ({pca.explained_variance_ratio_[x]:.2f})")
Expand All @@ -172,18 +173,24 @@ def plt_umap_labels_jointplot(g):
plt_pc_labels()
plt.tight_layout()
plt.savefig(f"{outdir}/z_pca.png")
plt.close()

# PCA -- Style 2 -- Scatter, with marginals
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], alpha=0.1, s=1, rasterized=True, height=4)
plt_pc_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/z_pca_marginals.png")
plt.close()

# PCA -- Style 3 -- Hexbin
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], height=4, kind="hex")
plt_pc_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/z_pca_hexbin.png")
try:
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], height=4, kind="hex")
plt_pc_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/z_pca_hexbin.png")
plt.close()
except ZeroDivisionError:
print("Data too small to produce hexbins!")

if umap_emb is not None:
# Style 1 -- Scatter
Expand All @@ -192,25 +199,34 @@ def plt_umap_labels_jointplot(g):
plt_umap_labels()
plt.tight_layout()
plt.savefig(f"{outdir}/umap.png")
plt.close()

# Style 2 -- Scatter with marginal distributions
g = sns.jointplot(
x=umap_emb[:, 0],
y=umap_emb[:, 1],
alpha=0.1,
s=1,
rasterized=True,
height=4,
)
plt_umap_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/umap_marginals.png")
try:
g = sns.jointplot(
x=umap_emb[:, 0],
y=umap_emb[:, 1],
alpha=0.1,
s=1,
rasterized=True,
height=4,
)
plt_umap_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/umap_marginals.png")
plt.close()
except ZeroDivisionError:
logger.warning("Data too small for marginal distribution scatterplots!")

# Style 3 -- Hexbin / heatmap
g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], kind="hex", height=4)
plt_umap_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/umap_hexbin.png")
try:
g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], kind="hex", height=4)
plt_umap_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/umap_hexbin.png")
plt.close()
except ZeroDivisionError:
logger.warning("Data too small to generate UMAP hexbins!")

# Plot kmeans sample points
colors = analysis._get_chimerax_colors(K)
Expand All @@ -224,6 +240,7 @@ def plt_umap_labels_jointplot(g):
plt_pc_labels()
plt.tight_layout()
plt.savefig(f"{outdir}/kmeans{K}/z_pca.png")
plt.close()

g = analysis.scatter_annotate_hex(
pc[:, 0],
Expand All @@ -235,6 +252,7 @@ def plt_umap_labels_jointplot(g):
plt_pc_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/kmeans{K}/z_pca_hex.png")
plt.close()

if umap_emb is not None:
analysis.scatter_annotate(
Expand All @@ -247,17 +265,22 @@ def plt_umap_labels_jointplot(g):
plt_umap_labels()
plt.tight_layout()
plt.savefig(f"{outdir}/kmeans{K}/umap.png")
plt.close()

g = analysis.scatter_annotate_hex(
umap_emb[:, 0],
umap_emb[:, 1],
centers_ind=centers_ind,
annotate=True,
colors=colors,
)
plt_umap_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/kmeans{K}/umap_hex.png")
try:
g = analysis.scatter_annotate_hex(
umap_emb[:, 0],
umap_emb[:, 1],
centers_ind=centers_ind,
annotate=True,
colors=colors,
)
plt_umap_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/kmeans{K}/umap_hex.png")
plt.close()
except ZeroDivisionError:
logger.warning("Data too small to generate UMAP annotated hexes!")

# Plot PC trajectories
for i in range(num_pcs):
Expand All @@ -274,12 +297,13 @@ def plt_umap_labels_jointplot(g):
plt_umap_labels()
plt.tight_layout()
plt.savefig(f"{outdir}/pc{i+1}/umap.png")
plt.close()

# UMAP, with PC traversal
z_pc_on_data, pc_ind = analysis.get_nearest_point(z, z_pc)
dists = ((z_pc_on_data - z_pc) ** 2).sum(axis=1) ** 0.5
if np.any(dists > 2):
logger.warn(
logger.warning(
f"Warning: PC{i+1} point locations in UMAP plot may be inaccurate"
)
plt.figure(figsize=(4, 4))
Expand All @@ -295,6 +319,7 @@ def plt_umap_labels_jointplot(g):
plt_umap_labels()
plt.tight_layout()
plt.savefig(f"{outdir}/pc{i+1}/umap_traversal.png")
plt.close()

# UMAP, with PC traversal, connected
plt.figure(figsize=(4, 4))
Expand All @@ -311,6 +336,7 @@ def plt_umap_labels_jointplot(g):
plt_umap_labels()
plt.tight_layout()
plt.savefig(f"{outdir}/pc{i+1}/umap_traversal_connected.png")
plt.close()

# 10 points, from 5th to 95th percentile of PC1 values
t = np.linspace(start, end, 10, endpoint=True)
Expand All @@ -325,6 +351,7 @@ def plt_umap_labels_jointplot(g):
plt_pc_labels(i, i + 1)
plt.tight_layout()
plt.savefig(f"{outdir}/pc{i+1}/pca_traversal.png")
plt.close()

if i > 0 and i == num_pcs - 1:
g = sns.jointplot(
Expand All @@ -340,6 +367,7 @@ def plt_umap_labels_jointplot(g):
plt_pc_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/pc{i+1}/pca_traversal_hex.png")
plt.close()


class VolumeGenerator:
Expand All @@ -354,8 +382,8 @@ def __init__(self, weights, config, vol_args={}, skip_vol=False):
def gen_volumes(self, outdir, z_values):
if self.skip_vol:
return
if not os.path.exists(outdir):
os.makedirs(outdir)

os.makedirs(outdir, exist_ok=True)
zfile = f"{outdir}/z_values.txt"
np.savetxt(zfile, z_values)
analysis.gen_volumes(self.weights, self.config, zfile, outdir, **self.vol_args)
Expand Down Expand Up @@ -475,10 +503,13 @@ def main(args):
# lazily look at the beginning of the notebook for the epoch number to update
with open(nb_outfile, "r") as f:
filter_ntbook = nbformat.read(f, as_version=nbformat.NO_CONVERT)
for i in range(5):
filter_ntbook["cells"][i]["source"] = filter_ntbook["cells"][i][
"source"
].replace("EPOCH = None", f"EPOCH = {epoch}")

for cell in filter_ntbook["cells"]:
cell["source"] = cell["source"].replace("EPOCH = None", f"EPOCH = {epoch}")
cell["source"] = cell["source"].replace(
"KMEANS = None", f"KMEANS = {args.ksample}"
)

with open(nb_outfile, "w") as f:
nbformat.write(filter_ntbook, f)

Expand Down
Loading

0 comments on commit 9dd0915

Please sign in to comment.