From 00f255e7b97ba8bb0e463f90c0b5030a96bb60a0 Mon Sep 17 00:00:00 2001
From: DSilva27 <david.silvas@udea.edu.co>
Date: Thu, 12 Dec 2024 15:20:46 -0500
Subject: [PATCH] update svd branch

---
 src/cryo_challenge/_svd/svd_pipeline.py | 156 +++++++++++-------------
 tests/config_files/test_config_svd.yaml |   2 +-
 tests/test_svd.py                       |   2 +-
 3 files changed, 76 insertions(+), 84 deletions(-)

diff --git a/src/cryo_challenge/_svd/svd_pipeline.py b/src/cryo_challenge/_svd/svd_pipeline.py
index 1886a30..5df13b9 100644
--- a/src/cryo_challenge/_svd/svd_pipeline.py
+++ b/src/cryo_challenge/_svd/svd_pipeline.py
@@ -1,22 +1,15 @@
 import torch
-import os
 
 from .svd_utils import (
     compute_distance_matrix,
     compute_common_embedding,
     project_to_gt_embedding,
 )
-from .svd_plots import (
-    plot_distance_matrix,
-    plot_common_embedding,
-    plot_gt_embedding,
-    plot_common_eigenvectors,
-)
 from ..data._io.svd_io_utils import load_submissions_svd, load_gt_svd
 
 
 def run_svd_with_ref(config: dict):
-    outputs_path = os.path.dirname(config["output_params"]["output_file"])
+    # outputs_path = os.path.dirname(config["output_params"]["output_file"])
 
     submissions_data = load_submissions_svd(config)
     gt_data = load_gt_svd(config)
@@ -38,54 +31,53 @@ def run_svd_with_ref(config: dict):
     torch.save(results, config["output_params"]["output_file"])
 
     if config["output_params"]["generate_plots"]:
-        outputs_fname_nopath_noext = os.path.basename(
-            config["output_params"]["output_file"]
-        )
-        outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
-        path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")
-
-        os.makedirs(path_plots, exist_ok=True)
-
-        print("Plotting distance matrix")
-        plot_distance_matrix(
-            dist_mtx_results["dist_matrix"],
-            dist_mtx_results["labels"],
-            title="SVD Distance Matrix",
-            save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
-        )
-
-        print("Plotting common embedding")
-        plot_common_embedding(
-            submissions_data,
-            common_embedding_results,
-            title="Common Embedding between submissions",
-            save_path=os.path.join(path_plots, "common_embedding.png"),
-        )
-
-        print("Plotting gt embedding")
-        plot_gt_embedding(
-            submissions_data,
-            gt_embedding_results,
-            title="",
-            save_path=os.path.join(path_plots, "gt_embedding.png"),
-        )
-
-        print("Plotting common eigenvectors")
-        plot_common_eigenvectors(
-            common_embedding_results["common_eigenvectors"],
-            title="Common Eigenvectors between submissions",
-            save_path=os.path.join(path_plots, "common_eigenvectors.png"),
-        )
-
-        raise Warning(
-            "Plots do not yet follow the same color scheme as the published paper"
+        raise NotImplementedError(
+            "Plots are currently turned off due to incompatibilities. Your results were saved right before this error triggered."
         )
+        # outputs_fname_nopath_noext = os.path.basename(
+        #     config["output_params"]["output_file"]
+        # )
+        # outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
+        # path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")
+
+        # os.makedirs(path_plots, exist_ok=True)
+
+        # print("Plotting distance matrix")
+        # plot_distance_matrix(
+        #     dist_mtx_results["dist_matrix"],
+        #     dist_mtx_results["labels"],
+        #     title="SVD Distance Matrix",
+        #     save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
+        # )
+
+        # print("Plotting common embedding")
+        # plot_common_embedding(
+        #     submissions_data,
+        #     common_embedding_results,
+        #     title="Common Embedding between submissions",
+        #     save_path=os.path.join(path_plots, "common_embedding.png"),
+        # )
+
+        # print("Plotting gt embedding")
+        # plot_gt_embedding(
+        #     submissions_data,
+        #     gt_embedding_results,
+        #     title="",
+        #     save_path=os.path.join(path_plots, "gt_embedding.png"),
+        # )
+
+        # print("Plotting common eigenvectors")
+        # plot_common_eigenvectors(
+        #     common_embedding_results["common_eigenvectors"],
+        #     title="Common Eigenvectors between submissions",
+        #     save_path=os.path.join(path_plots, "common_eigenvectors.png"),
+        # )
 
     return
 
 
 def run_svd_noref(config: dict):
-    outputs_path = os.path.dirname(config["output_params"]["output_file"])
+    # outputs_path = os.path.dirname(config["output_params"]["output_file"])
 
     submissions_data = load_submissions_svd(config)
     dist_mtx_results = compute_distance_matrix(submissions_data)
@@ -102,38 +94,38 @@ def run_svd_noref(config: dict):
     torch.save(results, config["output_params"]["output_file"])
 
     if config["output_params"]["generate_plots"]:
-        outputs_fname_nopath_noext = os.path.basename(
-            config["output_params"]["output_file"]
-        )
-        outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
-        path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")
-        os.makedirs(path_plots, exist_ok=True)
-
-        print("Plotting distance matrix")
-
-        plot_distance_matrix(
-            dist_mtx_results["dist_matrix"],
-            dist_mtx_results["labels"],
-            "SVD Distance Matrix",
-            save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
+        raise NotImplementedError(
+            "Plots are currently turned off due to incompatibilities. Your results were saved right before this error triggered."
         )
+        # outputs_fname_nopath_noext = os.path.basename(
+        #     config["output_params"]["output_file"]
+        # )
+        # outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
+        # path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")
+        # os.makedirs(path_plots, exist_ok=True)
+
+        # print("Plotting distance matrix")
+
+        # plot_distance_matrix(
+        #     dist_mtx_results["dist_matrix"],
+        #     dist_mtx_results["labels"],
+        #     "SVD Distance Matrix",
+        #     save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
+        # )
+
+        # print("Plotting common embedding")
+        # plot_common_embedding(
+        #     submissions_data,
+        #     common_embedding_results,
+        #     "Common Embedding between submissions",
+        #     save_path=os.path.join(path_plots, "common_embedding.png"),
+        # )
+
+        # print("Plotting common eigenvectors")
+        # plot_common_eigenvectors(
+        #     common_embedding_results["common_eigenvectors"],
+        #     title="Common Eigenvectors between submissions",
+        #     save_path=os.path.join(path_plots, "common_eigenvectors.png"),
+        # )
 
-        print("Plotting common embedding")
-        plot_common_embedding(
-            submissions_data,
-            common_embedding_results,
-            "Common Embedding between submissions",
-            save_path=os.path.join(path_plots, "common_embedding.png"),
-        )
-
-        print("Plotting common eigenvectors")
-        plot_common_eigenvectors(
-            common_embedding_results["common_eigenvectors"],
-            title="Common Eigenvectors between submissions",
-            save_path=os.path.join(path_plots, "common_eigenvectors.png"),
-        )
-
-        raise Warning(
-            "Plots do not yet follow the same color scheme as the published paper"
-        )
     return
diff --git a/tests/config_files/test_config_svd.yaml b/tests/config_files/test_config_svd.yaml
index b24ad41..a11162d 100644
--- a/tests/config_files/test_config_svd.yaml
+++ b/tests/config_files/test_config_svd.yaml
@@ -18,4 +18,4 @@ gt_params: # optional, if provided there will be extra results
 output_params:
   output_file: tests/results/svd/svd_result.pt
   save_svd_data: True
-  generate_plots: True
+  generate_plots: False
diff --git a/tests/test_svd.py b/tests/test_svd.py
index ea166ea..50c9a9b 100644
--- a/tests/test_svd.py
+++ b/tests/test_svd.py
@@ -2,6 +2,6 @@
 from cryo_challenge._commands import run_svd
 
 
-def test_run_preprocessing():
+def test_run_svd():
     args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"})
     run_svd.main(args)