diff --git a/src/cryo_challenge/_svd/svd_pipeline.py b/src/cryo_challenge/_svd/svd_pipeline.py index 7c9d840..50ba226 100644 --- a/src/cryo_challenge/_svd/svd_pipeline.py +++ b/src/cryo_challenge/_svd/svd_pipeline.py @@ -62,9 +62,7 @@ def run_svd_noref(config: dict): submissions_data = load_submissions_svd(config) dist_mtx_results = compute_distance_matrix(submissions_data) - common_embeddings_results = compute_common_embedding(submissions_data)[ - "common_embedding" - ] + common_embeddings_results = compute_common_embedding(submissions_data) results = { "distance_matrix_results": dist_mtx_results, diff --git a/src/cryo_challenge/_svd/svd_plots.py b/src/cryo_challenge/_svd/svd_plots.py index 47f44f1..5cc7770 100644 --- a/src/cryo_challenge/_svd/svd_plots.py +++ b/src/cryo_challenge/_svd/svd_plots.py @@ -3,6 +3,51 @@ import torch import numpy as np +# PLOT_SETUP = { +# "Ground Truth": {"color": "#e41a1c", "marker": "o"}, +# "Cookie Dough": {"color": "#377eb8", "marker": "v"}, +# "Mango": {"color": "#4daf4a", "marker": "^"}, +# "Vanilla": {"color": "#984ea3", "marker": "<"}, +# "Peanut Butter": {"color": "#ff7f00", "marker": ">"}, +# "Neapolitan": {"color": "#ffff33", "marker": "D"}, +# "Chocolate": {"color": "#a65628", "marker": "x"}, +# "Black Raspberry": {"color": "#f781bf", "marker": "*"}, +# "Cherry": {"color": "#999999", "marker": "s"}, +# "Salted Caramel": {"color": "#e41a1c", "marker": "p"}, +# "Chocolate Chip": {"color": "#377eb8", "marker": "P"}, +# "Rocky Road": {"color": "#4daf4a", "marker": "*"}, +# } + +COLORS = sns.color_palette("Set3", 12) +COLORS = [color for color in COLORS.as_hex()] + +MARKERS = ["o", "v", "^", "<", ">", "D", "x", "*", "s", "p", "P", "*"] +LABELS = [ + "Ground Truth", + "Cookie Dough", + "Mango", + "Vanilla", + "Peanut Butter", + "Neapolitan", + "Chocolate", + "Black Raspberry", + "Cherry", + "Salted Caramel", + "Chocolate Chip", + "Rocky Road", +] + +PLOT_SETUP = {} + +for i in range(12): + PLOT_SETUP[LABELS[i]] = { + "color": COLORS[i], + "marker": MARKERS[i], + } + +PLOT_SETUP["gt_left"] = {"color": "#e41a1c", "marker": "o"} +PLOT_SETUP["gt_right"] = {"color": "#377eb8", "marker": "v"} + def plot_distance_matrix(dist_matrix, labels, title="", save_path=None): fig, ax = plt.subplots() @@ -29,6 +74,19 @@ def plot_common_embedding( all_embeddings.append(embedding) labels.append(label) + plot_setup = {} + for i, label in enumerate(labels): + for possible_label in PLOT_SETUP.keys(): + if label in possible_label: + plot_setup[label] = PLOT_SETUP[possible_label] + + for label in labels: + if label not in plot_setup.keys(): + raise ValueError(f"Label {label} not found in PLOT_SETUP") + + if "gt_embedding" in embedding_results: + plot_setup["Ground Truth"] = PLOT_SETUP["Ground Truth"] + all_embeddings = torch.cat(all_embeddings, dim=0) weights = [] @@ -39,19 +97,23 @@ def plot_common_embedding( weights = weights / weights.sum() if "gt_embedding" in embedding_results: - n_cols = min(3, len(labels) + 1) - n_rows = min((len(labels) + 1) // n_cols, 1) + n_rows = np.sqrt(len(labels) + 1) + n_rows = np.ceil(n_rows).astype(int) + n_cols = np.ceil((len(labels) + 1) / n_rows).astype(int) else: - n_cols = min(3, len(labels)) - n_rows = min(len(labels) // n_cols, 1) + n_rows = np.sqrt(len(labels)) + n_rows = np.ceil(n_rows).astype(int) + n_cols = np.ceil(len(labels) / n_rows).astype(int) + print(n_cols, n_rows) fig, ax = plt.subplots( - n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), sharex=True, sharey=True + n_rows, n_cols, figsize=(n_cols * 5, n_rows * 3), sharex=True, sharey=True ) if n_rows == 1 and n_cols == 1: ax = np.array([ax]) + print(all_embeddings.shape) for i in range(len(labels)): sns.kdeplot( x=all_embeddings[:, pc1], @@ -61,6 +123,7 @@ def plot_common_embedding( cbar=False, ax=ax.flatten()[i], weights=weights, + alpha=0.8, ) if "gt_embedding" in embedding_results: @@ -72,21 +135,32 @@ def plot_common_embedding( cbar=False, ax=ax.flatten()[len(labels)], weights=weights, + # alpha=0.5, ) for i in range(len(labels)): pops = submissions_data[labels[i]]["populations"].numpy() pops = pops / pops.sum() + # put a value of i in the top left corner of each plot + ax.flatten()[i].text( + 0.05, + 0.95, + str(i + 1), + fontsize=12, + transform=ax.flatten()[i].transAxes, + verticalalignment="top", + bbox=dict(facecolor="white", alpha=0.5), + ) ax.flatten()[i].scatter( x=embedding_results["common_embedding"][labels[i]][:, pc1], y=embedding_results["common_embedding"][labels[i]][:, pc2], - color="red", + color=plot_setup[labels[i]]["color"], s=pops / pops.max() * 200, - marker="o", + marker=plot_setup[labels[i]]["marker"], linewidth=0.3, - edgecolor="white", - label=labels[i], + edgecolor="black", + label=str(i + 1) + ". " + labels[i], ) ax.flatten()[i].set_xticks([]) @@ -101,16 +175,24 @@ def plot_common_embedding( if "gt_embedding" in embedding_results: i_max += 1 - + ax.flatten()[i_max].text( + 0.05, + 0.95, + str(i_max + 1), + fontsize=12, + transform=ax.flatten()[i_max].transAxes, + verticalalignment="top", + bbox=dict(facecolor="white", alpha=0.5), + ) ax.flatten()[i_max].scatter( x=embedding_results["gt_embedding"][:, pc1], y=embedding_results["gt_embedding"][:, pc2], - color="red", + color=plot_setup["Ground Truth"]["color"], s=100, - marker="o", + marker=plot_setup["Ground Truth"]["marker"], linewidth=0.3, - edgecolor="white", - label="Ground Truth", + edgecolor="black", + label=f"{i_max + 1}. Ground Truth", ) ax.flatten()[i_max].set_xlabel(f"Z{pc1 + 1}", fontsize=12) @@ -148,13 +230,24 @@ def compute_gt_dist(z): gauss3 = gauss_pdf(z, -150, 750) return gauss1 + gauss2 + gauss3 - n_cols = 3 - n_rows = len(list(submissions_data.keys())) // n_cols + 1 + n_rows = np.sqrt(len(list(submissions_data.keys()))) + n_rows = np.ceil(n_rows).astype(int) + n_cols = np.ceil(len(list(submissions_data.keys())) / n_rows).astype(int) fig, ax = plt.subplots( n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), sharex=True, sharey=True ) + plot_setup = {} + for i, label in enumerate(submissions_data.keys()): + for possible_label in PLOT_SETUP.keys(): + if label in possible_label: + plot_setup[label] = PLOT_SETUP[possible_label] + + for label in submissions_data.keys(): + if label not in plot_setup.keys(): + raise ValueError(f"Label {label} not found in PLOT_SETUP") + low_gt = -227.927103122416 high_gt = 214.014930744738 Z = np.linspace(low_gt, high_gt, gt_embedding_results["gt_embedding"].shape[0]) @@ -171,6 +264,16 @@ def compute_gt_dist(z): i = 0 for label, embedding in gt_embedding_results["submission_embedding"].items(): + ax.flatten()[i].text( + 0.05, + 0.95, + str(i + 1), + fontsize=12, + transform=ax.flatten()[i].transAxes, + verticalalignment="top", + bbox=dict(facecolor="white", alpha=0.5), + ) + ax.flatten()[i].bar( edges[:-1], frq / frq.max(), @@ -186,12 +289,12 @@ def compute_gt_dist(z): ax.flatten()[i].scatter( x=embedding[:, 0], y=populations / populations.max(), - color="red", - marker="o", - s=60, + color=plot_setup[label]["color"], + marker=plot_setup[label]["marker"], + s=100, linewidth=0.3, - edgecolor="white", - label=label, + edgecolor="black", + label=f"{i+1}. {label}", ) # set x label only for the last row @@ -205,7 +308,7 @@ def compute_gt_dist(z): ax.flatten()[i].set_ylim(0.0, 1.1) ax.flatten()[i].set_xlim(x_axis[0] * 1.3, x_axis[-1] * 1.3) # set ticks to be maximum 5 ticks - ax.flatten()[i].set_yticks(np.arange(0, 1.25, 0.25)) + ax.flatten()[i].set_yticks(np.arange(0.25, 1.25, 0.25)) ax.flatten()[i].set_xticks([]) plt.subplots_adjust(wspace=0.0, hspace=0.0) diff --git a/src/cryo_challenge/_svd/svd_utils.py b/src/cryo_challenge/_svd/svd_utils.py index 4cdaa48..91d7aea 100644 --- a/src/cryo_challenge/_svd/svd_utils.py +++ b/src/cryo_challenge/_svd/svd_utils.py @@ -84,7 +84,7 @@ def compute_common_embedding(submissions_data, gt_data=None): for i, label in enumerate(labels): eigenvectors[i * shape_per_sub[0] : (i + 1) * shape_per_sub[0], :] = ( submissions_data[label]["eigenvectors"].T - ) + ) * submissions_data[label]["singular_values"][:, None] U, S, V = torch.linalg.svd(eigenvectors, full_matrices=False) @@ -92,14 +92,15 @@ def compute_common_embedding(submissions_data, gt_data=None): embeddings = {} for i, label in enumerate(labels): - Z_i = submissions_data[label]["u_matrices"] @ torch.diag( - submissions_data[label]["singular_values"] - ) + Z_i = submissions_data[label]["u_matrices"] # @ torch.diag( + # submissions_data[label]["singular_values"] + # ) Z_i_common = torch.einsum("ij, jk -> ik", Z_i, Z_common[i]) embeddings[labels[i]] = Z_i_common results = { "common_embedding": embeddings, + "singular_values": S, } if gt_data is not None: