diff --git a/src/cryo_challenge/_commands/run_svd.py b/src/cryo_challenge/_commands/run_svd.py index 3a73204..e74df38 100644 --- a/src/cryo_challenge/_commands/run_svd.py +++ b/src/cryo_challenge/_commands/run_svd.py @@ -35,7 +35,7 @@ def main(args): with open(args.config, "r") as file: config = yaml.safe_load(file) - config = SVDConfig(**config).dict() + config = SVDConfig(**config).model_dump() warnexists(config["output_params"]["output_file"]) mkbasedir(os.path.dirname(config["output_params"]["output_file"])) diff --git a/src/cryo_challenge/_svd/svd_pipeline.py b/src/cryo_challenge/_svd/svd_pipeline.py index 50ba226..adccb2d 100644 --- a/src/cryo_challenge/_svd/svd_pipeline.py +++ b/src/cryo_challenge/_svd/svd_pipeline.py @@ -6,7 +6,12 @@ compute_common_embedding, project_to_gt_embedding, ) -from .svd_plots import plot_distance_matrix, plot_common_embedding, plot_gt_embedding +from .svd_plots import ( + plot_distance_matrix, + plot_common_embedding, + plot_gt_embedding, + plot_common_eigenvectors, +) from ..data._io.svd_io_utils import load_submissions_svd, load_gt_svd @@ -33,25 +38,43 @@ def run_svd_with_ref(config: dict): torch.save(results, config["output_params"]["output_file"]) if config["output_params"]["generate_plots"]: + outputs_fname_nopath_noext = os.path.basename( + config["output_params"]["output_file"] + ) + outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0] + path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}") + + os.makedirs(path_plots, exist_ok=True) + + print("Plotting distance matrix") plot_distance_matrix( dist_mtx_results["dist_matrix"], dist_mtx_results["labels"], - "SVD Distance Matrix", - save_path=os.path.join(outputs_path, "svd_distance_matrix.png"), + title="SVD Distance Matrix", + save_path=os.path.join(path_plots, "svd_distance_matrix.png"), ) + print("Plotting common embedding") plot_common_embedding( submissions_data, common_embedding_results, - "Common Embedding between submissions", - save_path=os.path.join(outputs_path, "common_embedding.png"), + title="Common Embedding between submissions", + save_path=os.path.join(path_plots, "common_embedding.png"), ) + print("Plotting gt embedding") plot_gt_embedding( submissions_data, gt_embedding_results, - "", - save_path=os.path.join(outputs_path, "gt_embedding.png"), + title="", + save_path=os.path.join(path_plots, "gt_embedding.png"), + ) + + print("Plotting common eigenvectors") + plot_common_eigenvectors( + common_embedding_results["common_eigenvectors"], + title="Common Eigenvectors between submissions", + save_path=os.path.join(path_plots, "common_eigenvectors.png"), ) return @@ -62,11 +85,11 @@ def run_svd_noref(config: dict): submissions_data = load_submissions_svd(config) dist_mtx_results = compute_distance_matrix(submissions_data) - common_embeddings_results = compute_common_embedding(submissions_data) + common_embedding_results = compute_common_embedding(submissions_data) results = { "distance_matrix_results": dist_mtx_results, - "common_embedding_results": common_embeddings_results, + "common_embedding_results": common_embedding_results, } if config["output_params"]["save_svd_data"]: @@ -75,18 +98,35 @@ def run_svd_noref(config: dict): torch.save(results, config["output_params"]["output_file"]) if config["output_params"]["generate_plots"]: + outputs_fname_nopath_noext = os.path.basename( + config["output_params"]["output_file"] + ) + outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0] + path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}") + os.makedirs(path_plots, exist_ok=True) + + print("Plotting distance matrix") + plot_distance_matrix( dist_mtx_results["dist_matrix"], dist_mtx_results["labels"], "SVD Distance Matrix", - save_path=os.path.join(outputs_path, "svd_distance_matrix.png"), + save_path=os.path.join(path_plots, "svd_distance_matrix.png"), ) + print("Plotting common embedding") plot_common_embedding( submissions_data, - common_embeddings_results, + common_embedding_results, "Common Embedding between submissions", - save_path=os.path.join(outputs_path, "common_embedding.png"), + save_path=os.path.join(path_plots, "common_embedding.png"), + ) + + print("Plotting common eigenvectors") + plot_common_eigenvectors( + common_embedding_results["common_eigenvectors"], + title="Common Eigenvectors between submissions", + save_path=os.path.join(path_plots, "common_eigenvectors.png"), ) return diff --git a/src/cryo_challenge/_svd/svd_plots.py b/src/cryo_challenge/_svd/svd_plots.py index 5cc7770..965cbd0 100644 --- a/src/cryo_challenge/_svd/svd_plots.py +++ b/src/cryo_challenge/_svd/svd_plots.py @@ -18,10 +18,7 @@ # "Rocky Road": {"color": "#4daf4a", "marker": "*"}, # } -COLORS = sns.color_palette("Set3", 12) -COLORS = [color for color in COLORS.as_hex()] - -MARKERS = ["o", "v", "^", "<", ">", "D", "x", "*", "s", "p", "P", "*"] +MARKERS = ["o", "v", "^", "<", ">", "D", "X", "*", "s", "p", "P", "*", "h", "H"] LABELS = [ "Ground Truth", "Cookie Dough", @@ -35,17 +32,24 @@ "Salted Caramel", "Chocolate Chip", "Rocky Road", + "Pina Colada", + "Bubble Gum", ] +assert len(MARKERS) >= len(LABELS) + +COLORS = sns.color_palette("Set3", len(LABELS)) +COLORS = [color for color in COLORS.as_hex()] + PLOT_SETUP = {} -for i in range(12): +for i in range(len(LABELS)): PLOT_SETUP[LABELS[i]] = { "color": COLORS[i], "marker": MARKERS[i], } -PLOT_SETUP["gt_left"] = {"color": "#e41a1c", "marker": "o"} +PLOT_SETUP["Coffee"] = {"color": "#e41a1c", "marker": "o"} PLOT_SETUP["gt_right"] = {"color": "#377eb8", "marker": "v"} @@ -58,8 +62,8 @@ def plot_distance_matrix(dist_matrix, labels, title="", save_path=None): ax.set_title(title) if save_path is not None: - plt.savefig(save_path) - plt.show() + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1) + return @@ -77,7 +81,8 @@ def plot_common_embedding( plot_setup = {} for i, label in enumerate(labels): for possible_label in PLOT_SETUP.keys(): - if label in possible_label: + # print(label, possible_label) + if possible_label in label: plot_setup[label] = PLOT_SETUP[possible_label] for label in labels: @@ -106,19 +111,17 @@ def plot_common_embedding( n_rows = np.ceil(n_rows).astype(int) n_cols = np.ceil(len(labels) / n_rows).astype(int) - print(n_cols, n_rows) fig, ax = plt.subplots( n_rows, n_cols, figsize=(n_cols * 5, n_rows * 3), sharex=True, sharey=True ) if n_rows == 1 and n_cols == 1: ax = np.array([ax]) - print(all_embeddings.shape) for i in range(len(labels)): sns.kdeplot( x=all_embeddings[:, pc1], y=all_embeddings[:, pc2], - cmap="viridis", + cmap="gray", fill=True, cbar=False, ax=ax.flatten()[i], @@ -130,7 +133,7 @@ def plot_common_embedding( sns.kdeplot( x=all_embeddings[:, pc1], y=all_embeddings[:, pc2], - cmap="viridis", + cmap="gray", fill=True, cbar=False, ax=ax.flatten()[len(labels)], @@ -163,8 +166,8 @@ def plot_common_embedding( label=str(i + 1) + ". " + labels[i], ) - ax.flatten()[i].set_xticks([]) - ax.flatten()[i].set_yticks([]) + # ax.flatten()[i].set_xticks([]) + # ax.flatten()[i].set_yticks([]) if i >= n_rows: ax.flatten()[i].set_xlabel(f"Z{pc1 + 1}", fontsize=12) @@ -215,7 +218,7 @@ def plot_common_embedding( ) if save_path is not None: - plt.savefig(save_path) + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1) return @@ -237,19 +240,21 @@ def compute_gt_dist(z): fig, ax = plt.subplots( n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), sharex=True, sharey=True ) + if n_rows == 1 and n_cols == 1: + ax = np.array([ax]) plot_setup = {} for i, label in enumerate(submissions_data.keys()): for possible_label in PLOT_SETUP.keys(): - if label in possible_label: + if possible_label in label: plot_setup[label] = PLOT_SETUP[possible_label] for label in submissions_data.keys(): if label not in plot_setup.keys(): raise ValueError(f"Label {label} not found in PLOT_SETUP") - low_gt = -227.927103122416 - high_gt = 214.014930744738 + low_gt = -231.62100638454024 + high_gt = 243.32448171011487 Z = np.linspace(low_gt, high_gt, gt_embedding_results["gt_embedding"].shape[0]) x_axis = np.linspace( torch.min(gt_embedding_results["gt_embedding"][:, 0]), @@ -327,6 +332,61 @@ def compute_gt_dist(z): ) if save_path is not None: - plt.savefig(save_path) + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1) + + return + + +def plot_common_eigenvectors( + common_eigenvectors, n_eig_to_plot=None, title="", save_path=None +): + n_eig_to_plot = min(10, len(common_eigenvectors)) + n_cols = 5 + n_rows = int(np.ceil(n_eig_to_plot / n_cols)) + + fig, ax = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 5)) + + box_size = int(round((common_eigenvectors[0].shape[-1]) ** (1 / 3))) + for i in range(n_eig_to_plot): + eigvol = common_eigenvectors[i].reshape(box_size, box_size, box_size) + + mask_small = torch.where(torch.abs(eigvol) < 1e-3) + mask_pos = torch.where(eigvol > 0) + mask_neg = torch.where(eigvol < 0) + + eigvol_pos = torch.zeros_like(eigvol) + eigvol_neg = torch.zeros_like(eigvol) + + eigvol_pos[mask_pos] = 1.0 + eigvol_neg[mask_neg] = -1.0 + + eigvol_for_img = eigvol_neg + eigvol_pos + eigvol_for_img[mask_small] = 0.0 + + ax.flatten()[i].imshow( + eigvol_for_img.sum(0), cmap="coolwarm", label=f"Eigenvector {i}" + ) + ax.flatten()[i].set_title(f"Eigenvector {i}") + ax.flatten()[i].axis("off") + i_max = i + + if i_max < n_cols * n_rows: + for j in range(i_max + 1, n_cols * n_rows): + ax.flatten()[j].axis("off") + + plt.subplots_adjust(wspace=0.0) + + # add a colorbar for the whole figure + fig.colorbar( + ax.flatten()[i].imshow(eigvol_for_img.sum(0), cmap="coolwarm"), + ax=ax, + orientation="horizontal", + label="Eigenvector value (neg or pos)", + ) + + fig.suptitle(title, fontsize=16) + + if save_path is not None: + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1) return diff --git a/src/cryo_challenge/_svd/svd_utils.py b/src/cryo_challenge/_svd/svd_utils.py index 91d7aea..79c9a39 100644 --- a/src/cryo_challenge/_svd/svd_utils.py +++ b/src/cryo_challenge/_svd/svd_utils.py @@ -29,6 +29,24 @@ def sort_matrix_using_gt(dist_matrix: torch.Tensor, labels: np.ndarray): return dist_matrix, labels +def sort_matrix(dist_matrix, labels): + dist_matrix = dist_matrix.clone() + labels = labels.copy() + + # Sort by sum of rows + row_sum = torch.sum(dist_matrix, dim=0) + sort_idx = torch.argsort(row_sum, descending=True) + dist_matrix = dist_matrix[:, sort_idx][sort_idx] + labels = labels[sort_idx.numpy()] + + # Sort the first row + sort_idx = torch.argsort(dist_matrix[:, 0], descending=True) + dist_matrix = dist_matrix[:, sort_idx][sort_idx] + labels = labels[sort_idx.numpy()] + + return dist_matrix, labels + + def compute_distance_matrix(submissions_data, gt_data=None): n_subs = len(list(submissions_data.keys())) labels = list(submissions_data.keys()) @@ -64,7 +82,10 @@ def compute_distance_matrix(submissions_data, gt_data=None): dist_matrix, labels = sort_matrix_using_gt(dist_matrix, labels) - labels = np.array(labels) + else: + labels = np.array(labels) + dist_matrix, labels = sort_matrix(dist_matrix, labels) + results = {"dist_matrix": dist_matrix, "labels": labels} return results @@ -101,6 +122,7 @@ def compute_common_embedding(submissions_data, gt_data=None): results = { "common_embedding": embeddings, "singular_values": S, + "common_eigenvectors": V, } if gt_data is not None: diff --git a/src/cryo_challenge/data/_io/svd_io_utils.py b/src/cryo_challenge/data/_io/svd_io_utils.py index bc4a12f..610e3ef 100644 --- a/src/cryo_challenge/data/_io/svd_io_utils.py +++ b/src/cryo_challenge/data/_io/svd_io_utils.py @@ -59,6 +59,9 @@ def load_submissions_svd( label = submission["id"] populations = submission["populations"] + if not isinstance(populations, torch.Tensor): + populations = torch.tensor(populations) + volumes = submission["volumes"] if config["normalize_params"]["mask_path"] is not None: volumes = volumes * mask @@ -101,7 +104,7 @@ def load_submissions_svd( ) submissions_data[label] = { - "populations": torch.tensor(populations / populations.sum()), + "populations": populations / populations.sum(), "u_matrices": u_matrices.clone(), "singular_values": singular_values.clone(), "eigenvectors": eigenvectors.clone(),