From 00f255e7b97ba8bb0e463f90c0b5030a96bb60a0 Mon Sep 17 00:00:00 2001 From: DSilva27 <david.silvas@udea.edu.co> Date: Thu, 12 Dec 2024 15:20:46 -0500 Subject: [PATCH] update svd branch --- src/cryo_challenge/_svd/svd_pipeline.py | 156 +++++++++++------------- tests/config_files/test_config_svd.yaml | 2 +- tests/test_svd.py | 2 +- 3 files changed, 76 insertions(+), 84 deletions(-) diff --git a/src/cryo_challenge/_svd/svd_pipeline.py b/src/cryo_challenge/_svd/svd_pipeline.py index 1886a30..5df13b9 100644 --- a/src/cryo_challenge/_svd/svd_pipeline.py +++ b/src/cryo_challenge/_svd/svd_pipeline.py @@ -1,22 +1,15 @@ import torch -import os from .svd_utils import ( compute_distance_matrix, compute_common_embedding, project_to_gt_embedding, ) -from .svd_plots import ( - plot_distance_matrix, - plot_common_embedding, - plot_gt_embedding, - plot_common_eigenvectors, -) from ..data._io.svd_io_utils import load_submissions_svd, load_gt_svd def run_svd_with_ref(config: dict): - outputs_path = os.path.dirname(config["output_params"]["output_file"]) + # outputs_path = os.path.dirname(config["output_params"]["output_file"]) submissions_data = load_submissions_svd(config) gt_data = load_gt_svd(config) @@ -38,54 +31,53 @@ def run_svd_with_ref(config: dict): torch.save(results, config["output_params"]["output_file"]) if config["output_params"]["generate_plots"]: - outputs_fname_nopath_noext = os.path.basename( - config["output_params"]["output_file"] - ) - outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0] - path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}") - - os.makedirs(path_plots, exist_ok=True) - - print("Plotting distance matrix") - plot_distance_matrix( - dist_mtx_results["dist_matrix"], - dist_mtx_results["labels"], - title="SVD Distance Matrix", - save_path=os.path.join(path_plots, "svd_distance_matrix.png"), - ) - - print("Plotting common embedding") - plot_common_embedding( - submissions_data, - common_embedding_results, - title="Common Embedding between submissions", - save_path=os.path.join(path_plots, "common_embedding.png"), - ) - - print("Plotting gt embedding") - plot_gt_embedding( - submissions_data, - gt_embedding_results, - title="", - save_path=os.path.join(path_plots, "gt_embedding.png"), - ) - - print("Plotting common eigenvectors") - plot_common_eigenvectors( - common_embedding_results["common_eigenvectors"], - title="Common Eigenvectors between submissions", - save_path=os.path.join(path_plots, "common_eigenvectors.png"), - ) - - raise Warning( - "Plots do not yet follow the same color scheme as the published paper" + raise NotImplementedError( + "Plots are currently turned off due to incompatibilities. Your results were saved right before this error triggered." ) + # outputs_fname_nopath_noext = os.path.basename( + # config["output_params"]["output_file"] + # ) + # outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0] + # path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}") + + # os.makedirs(path_plots, exist_ok=True) + + # print("Plotting distance matrix") + # plot_distance_matrix( + # dist_mtx_results["dist_matrix"], + # dist_mtx_results["labels"], + # title="SVD Distance Matrix", + # save_path=os.path.join(path_plots, "svd_distance_matrix.png"), + # ) + + # print("Plotting common embedding") + # plot_common_embedding( + # submissions_data, + # common_embedding_results, + # title="Common Embedding between submissions", + # save_path=os.path.join(path_plots, "common_embedding.png"), + # ) + + # print("Plotting gt embedding") + # plot_gt_embedding( + # submissions_data, + # gt_embedding_results, + # title="", + # save_path=os.path.join(path_plots, "gt_embedding.png"), + # ) + + # print("Plotting common eigenvectors") + # plot_common_eigenvectors( + # common_embedding_results["common_eigenvectors"], + # title="Common Eigenvectors between submissions", + # save_path=os.path.join(path_plots, "common_eigenvectors.png"), + # ) return def run_svd_noref(config: dict): - outputs_path = os.path.dirname(config["output_params"]["output_file"]) + # outputs_path = os.path.dirname(config["output_params"]["output_file"]) submissions_data = load_submissions_svd(config) dist_mtx_results = compute_distance_matrix(submissions_data) @@ -102,38 +94,38 @@ def run_svd_noref(config: dict): torch.save(results, config["output_params"]["output_file"]) if config["output_params"]["generate_plots"]: - outputs_fname_nopath_noext = os.path.basename( - config["output_params"]["output_file"] - ) - outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0] - path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}") - os.makedirs(path_plots, exist_ok=True) - - print("Plotting distance matrix") - - plot_distance_matrix( - dist_mtx_results["dist_matrix"], - dist_mtx_results["labels"], - "SVD Distance Matrix", - save_path=os.path.join(path_plots, "svd_distance_matrix.png"), + raise NotImplementedError( + "Plots are currently turned off due to incompatibilities. Your results were saved right before this error triggered." ) + # outputs_fname_nopath_noext = os.path.basename( + # config["output_params"]["output_file"] + # ) + # outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0] + # path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}") + # os.makedirs(path_plots, exist_ok=True) + + # print("Plotting distance matrix") + + # plot_distance_matrix( + # dist_mtx_results["dist_matrix"], + # dist_mtx_results["labels"], + # "SVD Distance Matrix", + # save_path=os.path.join(path_plots, "svd_distance_matrix.png"), + # ) + + # print("Plotting common embedding") + # plot_common_embedding( + # submissions_data, + # common_embedding_results, + # "Common Embedding between submissions", + # save_path=os.path.join(path_plots, "common_embedding.png"), + # ) + + # print("Plotting common eigenvectors") + # plot_common_eigenvectors( + # common_embedding_results["common_eigenvectors"], + # title="Common Eigenvectors between submissions", + # save_path=os.path.join(path_plots, "common_eigenvectors.png"), + # ) - print("Plotting common embedding") - plot_common_embedding( - submissions_data, - common_embedding_results, - "Common Embedding between submissions", - save_path=os.path.join(path_plots, "common_embedding.png"), - ) - - print("Plotting common eigenvectors") - plot_common_eigenvectors( - common_embedding_results["common_eigenvectors"], - title="Common Eigenvectors between submissions", - save_path=os.path.join(path_plots, "common_eigenvectors.png"), - ) - - raise Warning( - "Plots do not yet follow the same color scheme as the published paper" - ) return diff --git a/tests/config_files/test_config_svd.yaml b/tests/config_files/test_config_svd.yaml index b24ad41..a11162d 100644 --- a/tests/config_files/test_config_svd.yaml +++ b/tests/config_files/test_config_svd.yaml @@ -18,4 +18,4 @@ gt_params: # optional, if provided there will be extra results output_params: output_file: tests/results/svd/svd_result.pt save_svd_data: True - generate_plots: True + generate_plots: False diff --git a/tests/test_svd.py b/tests/test_svd.py index ea166ea..50c9a9b 100644 --- a/tests/test_svd.py +++ b/tests/test_svd.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_svd -def test_run_preprocessing(): +def test_run_svd(): args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"}) run_svd.main(args)