Skip to content

Commit

Permalink
update svd branch
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Dec 12, 2024
1 parent 6600822 commit 00f255e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 84 deletions.
156 changes: 74 additions & 82 deletions src/cryo_challenge/_svd/svd_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
import torch
import os

from .svd_utils import (
compute_distance_matrix,
compute_common_embedding,
project_to_gt_embedding,
)
from .svd_plots import (
plot_distance_matrix,
plot_common_embedding,
plot_gt_embedding,
plot_common_eigenvectors,
)
from ..data._io.svd_io_utils import load_submissions_svd, load_gt_svd


def run_svd_with_ref(config: dict):
outputs_path = os.path.dirname(config["output_params"]["output_file"])
# outputs_path = os.path.dirname(config["output_params"]["output_file"])

submissions_data = load_submissions_svd(config)
gt_data = load_gt_svd(config)
Expand All @@ -38,54 +31,53 @@ def run_svd_with_ref(config: dict):
torch.save(results, config["output_params"]["output_file"])

if config["output_params"]["generate_plots"]:
outputs_fname_nopath_noext = os.path.basename(
config["output_params"]["output_file"]
)
outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")

os.makedirs(path_plots, exist_ok=True)

print("Plotting distance matrix")
plot_distance_matrix(
dist_mtx_results["dist_matrix"],
dist_mtx_results["labels"],
title="SVD Distance Matrix",
save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
)

print("Plotting common embedding")
plot_common_embedding(
submissions_data,
common_embedding_results,
title="Common Embedding between submissions",
save_path=os.path.join(path_plots, "common_embedding.png"),
)

print("Plotting gt embedding")
plot_gt_embedding(
submissions_data,
gt_embedding_results,
title="",
save_path=os.path.join(path_plots, "gt_embedding.png"),
)

print("Plotting common eigenvectors")
plot_common_eigenvectors(
common_embedding_results["common_eigenvectors"],
title="Common Eigenvectors between submissions",
save_path=os.path.join(path_plots, "common_eigenvectors.png"),
)

raise Warning(
"Plots do not yet follow the same color scheme as the published paper"
raise NotImplementedError(
"Plots are currently turned off due to incompatibilities. Your results were saved right before this error triggered."
)
# outputs_fname_nopath_noext = os.path.basename(
# config["output_params"]["output_file"]
# )
# outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
# path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")

# os.makedirs(path_plots, exist_ok=True)

# print("Plotting distance matrix")
# plot_distance_matrix(
# dist_mtx_results["dist_matrix"],
# dist_mtx_results["labels"],
# title="SVD Distance Matrix",
# save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
# )

# print("Plotting common embedding")
# plot_common_embedding(
# submissions_data,
# common_embedding_results,
# title="Common Embedding between submissions",
# save_path=os.path.join(path_plots, "common_embedding.png"),
# )

# print("Plotting gt embedding")
# plot_gt_embedding(
# submissions_data,
# gt_embedding_results,
# title="",
# save_path=os.path.join(path_plots, "gt_embedding.png"),
# )

# print("Plotting common eigenvectors")
# plot_common_eigenvectors(
# common_embedding_results["common_eigenvectors"],
# title="Common Eigenvectors between submissions",
# save_path=os.path.join(path_plots, "common_eigenvectors.png"),
# )

return


def run_svd_noref(config: dict):
outputs_path = os.path.dirname(config["output_params"]["output_file"])
# outputs_path = os.path.dirname(config["output_params"]["output_file"])

submissions_data = load_submissions_svd(config)
dist_mtx_results = compute_distance_matrix(submissions_data)
Expand All @@ -102,38 +94,38 @@ def run_svd_noref(config: dict):
torch.save(results, config["output_params"]["output_file"])

if config["output_params"]["generate_plots"]:
outputs_fname_nopath_noext = os.path.basename(
config["output_params"]["output_file"]
)
outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")
os.makedirs(path_plots, exist_ok=True)

print("Plotting distance matrix")

plot_distance_matrix(
dist_mtx_results["dist_matrix"],
dist_mtx_results["labels"],
"SVD Distance Matrix",
save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
raise NotImplementedError(
"Plots are currently turned off due to incompatibilities. Your results were saved right before this error triggered."
)
# outputs_fname_nopath_noext = os.path.basename(
# config["output_params"]["output_file"]
# )
# outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
# path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")
# os.makedirs(path_plots, exist_ok=True)

# print("Plotting distance matrix")

# plot_distance_matrix(
# dist_mtx_results["dist_matrix"],
# dist_mtx_results["labels"],
# "SVD Distance Matrix",
# save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
# )

# print("Plotting common embedding")
# plot_common_embedding(
# submissions_data,
# common_embedding_results,
# "Common Embedding between submissions",
# save_path=os.path.join(path_plots, "common_embedding.png"),
# )

# print("Plotting common eigenvectors")
# plot_common_eigenvectors(
# common_embedding_results["common_eigenvectors"],
# title="Common Eigenvectors between submissions",
# save_path=os.path.join(path_plots, "common_eigenvectors.png"),
# )

print("Plotting common embedding")
plot_common_embedding(
submissions_data,
common_embedding_results,
"Common Embedding between submissions",
save_path=os.path.join(path_plots, "common_embedding.png"),
)

print("Plotting common eigenvectors")
plot_common_eigenvectors(
common_embedding_results["common_eigenvectors"],
title="Common Eigenvectors between submissions",
save_path=os.path.join(path_plots, "common_eigenvectors.png"),
)

raise Warning(
"Plots do not yet follow the same color scheme as the published paper"
)
return
2 changes: 1 addition & 1 deletion tests/config_files/test_config_svd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ gt_params: # optional, if provided there will be extra results
output_params:
output_file: tests/results/svd/svd_result.pt
save_svd_data: True
generate_plots: True
generate_plots: False
2 changes: 1 addition & 1 deletion tests/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from cryo_challenge._commands import run_svd


def test_run_preprocessing():
def test_run_svd():
args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"})
run_svd.main(args)

0 comments on commit 00f255e

Please sign in to comment.