Commit

fix issues with plotting
DSilva27 committed Sep 3, 2024
1 parent 3beaaab commit b6f899d
Showing 3 changed files with 131 additions and 29 deletions.
4 changes: 1 addition & 3 deletions src/cryo_challenge/_svd/svd_pipeline.py
@@ -62,9 +62,7 @@ def run_svd_noref(config: dict):

submissions_data = load_submissions_svd(config)
dist_mtx_results = compute_distance_matrix(submissions_data)
-    common_embeddings_results = compute_common_embedding(submissions_data)[
-        "common_embedding"
-    ]
+    common_embeddings_results = compute_common_embedding(submissions_data)

results = {
"distance_matrix_results": dist_mtx_results,
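For context, a minimal sketch with a hypothetical stand-in function (mock data, not the real cryo_challenge code): keeping the full dictionary returned by compute_common_embedding means downstream code can reach both the per-submission embeddings and the singular values that svd_utils.py (below) now adds to the results.

def compute_common_embedding(submissions_data):
    # hypothetical stand-in mirroring the keys built in svd_utils.py below
    return {
        "common_embedding": {"Mango": [[0.1, 0.2]], "Vanilla": [[0.3, 0.1]]},
        "singular_values": [3.2, 1.1],
    }

common_embeddings_results = compute_common_embedding({})
print(sorted(common_embeddings_results.keys()))  # ['common_embedding', 'singular_values']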
147 changes: 125 additions & 22 deletions src/cryo_challenge/_svd/svd_plots.py
@@ -3,6 +3,51 @@
import torch
import numpy as np

# PLOT_SETUP = {
# "Ground Truth": {"color": "#e41a1c", "marker": "o"},
# "Cookie Dough": {"color": "#377eb8", "marker": "v"},
# "Mango": {"color": "#4daf4a", "marker": "^"},
# "Vanilla": {"color": "#984ea3", "marker": "<"},
# "Peanut Butter": {"color": "#ff7f00", "marker": ">"},
# "Neapolitan": {"color": "#ffff33", "marker": "D"},
# "Chocolate": {"color": "#a65628", "marker": "x"},
# "Black Raspberry": {"color": "#f781bf", "marker": "*"},
# "Cherry": {"color": "#999999", "marker": "s"},
# "Salted Caramel": {"color": "#e41a1c", "marker": "p"},
# "Chocolate Chip": {"color": "#377eb8", "marker": "P"},
# "Rocky Road": {"color": "#4daf4a", "marker": "*"},
# }

COLORS = sns.color_palette("Set3", 12)
COLORS = [color for color in COLORS.as_hex()]

MARKERS = ["o", "v", "^", "<", ">", "D", "x", "*", "s", "p", "P", "*"]
LABELS = [
"Ground Truth",
"Cookie Dough",
"Mango",
"Vanilla",
"Peanut Butter",
"Neapolitan",
"Chocolate",
"Black Raspberry",
"Cherry",
"Salted Caramel",
"Chocolate Chip",
"Rocky Road",
]

PLOT_SETUP = {}

for i in range(12):
PLOT_SETUP[LABELS[i]] = {
"color": COLORS[i],
"marker": MARKERS[i],
}

PLOT_SETUP["gt_left"] = {"color": "#e41a1c", "marker": "o"}
PLOT_SETUP["gt_right"] = {"color": "#377eb8", "marker": "v"}


def plot_distance_matrix(dist_matrix, labels, title="", save_path=None):
fig, ax = plt.subplots()
@@ -29,6 +74,19 @@ def plot_common_embedding(
all_embeddings.append(embedding)
labels.append(label)

plot_setup = {}
for i, label in enumerate(labels):
for possible_label in PLOT_SETUP.keys():
if label in possible_label:
plot_setup[label] = PLOT_SETUP[possible_label]

for label in labels:
if label not in plot_setup.keys():
raise ValueError(f"Label {label} not found in PLOT_SETUP")

if "gt_embedding" in embedding_results:
plot_setup["Ground Truth"] = PLOT_SETUP["Ground Truth"]

all_embeddings = torch.cat(all_embeddings, dim=0)

weights = []
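The style lookup added above matches submission labels against PLOT_SETUP keys by substring (label in possible_label), and later matches overwrite earlier ones. A standalone sketch with a toy dictionary (not the module's PLOT_SETUP) of that behaviour:

PLOT_SETUP = {"Chocolate": {"marker": "x"}, "Chocolate Chip": {"marker": "P"}}

def resolve_style(label: str) -> dict:
    style = None
    for possible_label in PLOT_SETUP:  # substring test, as in plot_common_embedding
        if label in possible_label:
            style = PLOT_SETUP[possible_label]  # later matches win
    if style is None:
        raise ValueError(f"Label {label} not found in PLOT_SETUP")
    return style

print(resolve_style("Chocolate Chip"))  # {'marker': 'P'}
print(resolve_style("Chocolate"))       # also {'marker': 'P'}: both keys contain it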
@@ -39,19 +97,23 @@ def plot_common_embedding(
weights = weights / weights.sum()

if "gt_embedding" in embedding_results:
-        n_cols = min(3, len(labels) + 1)
-        n_rows = min((len(labels) + 1) // n_cols, 1)
+        n_rows = np.sqrt(len(labels) + 1)
+        n_rows = np.ceil(n_rows).astype(int)
+        n_cols = np.ceil((len(labels) + 1) / n_rows).astype(int)

else:
-        n_cols = min(3, len(labels))
-        n_rows = min(len(labels) // n_cols, 1)
+        n_rows = np.sqrt(len(labels))
+        n_rows = np.ceil(n_rows).astype(int)
+        n_cols = np.ceil(len(labels) / n_rows).astype(int)

print(n_cols, n_rows)
fig, ax = plt.subplots(
-        n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), sharex=True, sharey=True
+        n_rows, n_cols, figsize=(n_cols * 5, n_rows * 3), sharex=True, sharey=True
)
if n_rows == 1 and n_cols == 1:
ax = np.array([ax])

print(all_embeddings.shape)
for i in range(len(labels)):
sns.kdeplot(
x=all_embeddings[:, pc1],
@@ -61,6 +123,7 @@ def plot_common_embedding(
cbar=False,
ax=ax.flatten()[i],
weights=weights,
alpha=0.8,
)

if "gt_embedding" in embedding_results:
@@ -72,21 +135,32 @@ def plot_common_embedding(
cbar=False,
ax=ax.flatten()[len(labels)],
weights=weights,
# alpha=0.5,
)

for i in range(len(labels)):
pops = submissions_data[labels[i]]["populations"].numpy()
pops = pops / pops.sum()

# put a value of i in the top left corner of each plot
ax.flatten()[i].text(
0.05,
0.95,
str(i + 1),
fontsize=12,
transform=ax.flatten()[i].transAxes,
verticalalignment="top",
bbox=dict(facecolor="white", alpha=0.5),
)
ax.flatten()[i].scatter(
x=embedding_results["common_embedding"][labels[i]][:, pc1],
y=embedding_results["common_embedding"][labels[i]][:, pc2],
color="red",
color=plot_setup[labels[i]]["color"],
s=pops / pops.max() * 200,
marker="o",
marker=plot_setup[labels[i]]["marker"],
linewidth=0.3,
edgecolor="white",
label=labels[i],
edgecolor="black",
label=str(i + 1) + ". " + labels[i],
)

ax.flatten()[i].set_xticks([])
@@ -101,16 +175,24 @@ def plot_common_embedding(

if "gt_embedding" in embedding_results:
i_max += 1

ax.flatten()[i_max].text(
0.05,
0.95,
str(i_max + 1),
fontsize=12,
transform=ax.flatten()[i_max].transAxes,
verticalalignment="top",
bbox=dict(facecolor="white", alpha=0.5),
)
ax.flatten()[i_max].scatter(
x=embedding_results["gt_embedding"][:, pc1],
y=embedding_results["gt_embedding"][:, pc2],
color="red",
color=plot_setup["Ground Truth"]["color"],
s=100,
marker="o",
marker=plot_setup["Ground Truth"]["marker"],
linewidth=0.3,
edgecolor="white",
label="Ground Truth",
edgecolor="black",
label=f"{i_max + 1}. Ground Truth",
)

ax.flatten()[i_max].set_xlabel(f"Z{pc1 + 1}", fontsize=12)
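The new annotations follow one convention: a boxed index in the top-left corner of each axis, mirrored by a numbered legend entry for the same panel. A standalone matplotlib sketch with toy data (not the challenge plots) of that pattern:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.scatter([0.3, 0.6], [0.4, 0.7], marker="v", label="2. Cookie Dough")
ax.text(
    0.05, 0.95, "2",
    fontsize=12,
    transform=ax.transAxes,  # position in axes coordinates, not data coordinates
    verticalalignment="top",
    bbox=dict(facecolor="white", alpha=0.5),
)
ax.legend()
plt.show()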
@@ -148,13 +230,24 @@ def compute_gt_dist(z):
gauss3 = gauss_pdf(z, -150, 750)
return gauss1 + gauss2 + gauss3

-    n_cols = 3
-    n_rows = len(list(submissions_data.keys())) // n_cols + 1
+    n_rows = np.sqrt(len(list(submissions_data.keys())))
+    n_rows = np.ceil(n_rows).astype(int)
+    n_cols = np.ceil(len(list(submissions_data.keys())) / n_rows).astype(int)

fig, ax = plt.subplots(
n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), sharex=True, sharey=True
)

plot_setup = {}
for i, label in enumerate(submissions_data.keys()):
for possible_label in PLOT_SETUP.keys():
if label in possible_label:
plot_setup[label] = PLOT_SETUP[possible_label]

for label in submissions_data.keys():
if label not in plot_setup.keys():
raise ValueError(f"Label {label} not found in PLOT_SETUP")

low_gt = -227.927103122416
high_gt = 214.014930744738
Z = np.linspace(low_gt, high_gt, gt_embedding_results["gt_embedding"].shape[0])
@@ -171,6 +264,16 @@ def compute_gt_dist(z):

i = 0
for label, embedding in gt_embedding_results["submission_embedding"].items():
ax.flatten()[i].text(
0.05,
0.95,
str(i + 1),
fontsize=12,
transform=ax.flatten()[i].transAxes,
verticalalignment="top",
bbox=dict(facecolor="white", alpha=0.5),
)

ax.flatten()[i].bar(
edges[:-1],
frq / frq.max(),
@@ -186,12 +289,12 @@ def compute_gt_dist(z):
ax.flatten()[i].scatter(
x=embedding[:, 0],
y=populations / populations.max(),
color="red",
marker="o",
s=60,
color=plot_setup[label]["color"],
marker=plot_setup[label]["marker"],
s=100,
linewidth=0.3,
edgecolor="white",
label=label,
edgecolor="black",
label=f"{i+1}. {label}",
)

# set x label only for the last row
@@ -205,7 +308,7 @@ def compute_gt_dist(z):
ax.flatten()[i].set_ylim(0.0, 1.1)
ax.flatten()[i].set_xlim(x_axis[0] * 1.3, x_axis[-1] * 1.3)
# set ticks to be maximum 5 ticks
-        ax.flatten()[i].set_yticks(np.arange(0, 1.25, 0.25))
+        ax.flatten()[i].set_yticks(np.arange(0.25, 1.25, 0.25))
ax.flatten()[i].set_xticks([])

plt.subplots_adjust(wspace=0.0, hspace=0.0)
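Both plotting functions now size their subplot grid from the number of panels instead of hard-coding three columns. A standalone sketch with a hypothetical helper (not in the module) of the same rule:

import numpy as np

def grid_shape(n_panels: int) -> tuple[int, int]:
    # near-square layout: rows = ceil(sqrt(n)), columns = ceil(n / rows)
    n_rows = int(np.ceil(np.sqrt(n_panels)))
    n_cols = int(np.ceil(n_panels / n_rows))
    return n_rows, n_cols

print(grid_shape(12))  # (4, 3): 12 axes, none left over
print(grid_shape(5))   # (3, 2): 6 axes, one left empty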
9 changes: 5 additions & 4 deletions src/cryo_challenge/_svd/svd_utils.py
@@ -84,22 +84,23 @@ def compute_common_embedding(submissions_data, gt_data=None):
for i, label in enumerate(labels):
eigenvectors[i * shape_per_sub[0] : (i + 1) * shape_per_sub[0], :] = (
submissions_data[label]["eigenvectors"].T
-        )
+        ) * submissions_data[label]["singular_values"][:, None]

U, S, V = torch.linalg.svd(eigenvectors, full_matrices=False)

Z_common = (U @ torch.diag(S)).reshape(n_subs, shape_per_sub[0], -1)
embeddings = {}

for i, label in enumerate(labels):
Z_i = submissions_data[label]["u_matrices"] @ torch.diag(
submissions_data[label]["singular_values"]
)
Z_i = submissions_data[label]["u_matrices"] # @ torch.diag(
# submissions_data[label]["singular_values"]
# )
Z_i_common = torch.einsum("ij, jk -> ik", Z_i, Z_common[i])
embeddings[labels[i]] = Z_i_common

results = {
"common_embedding": embeddings,
"singular_values": S,
}

if gt_data is not None:
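The substance of the svd_utils.py change: each submission's eigenvectors are scaled by its singular values before the joint SVD, and the per-submission coordinates are then taken from the U matrix alone rather than U @ diag(S). A self-contained numerical sketch on random toy data (made-up dimensions, not the challenge submissions):

import torch

torch.manual_seed(0)
n_subs, n_components, n_voxels, n_states = 3, 4, 50, 20

submissions = []
for _ in range(n_subs):
    U, S, Vh = torch.linalg.svd(torch.randn(n_states, n_voxels), full_matrices=False)
    submissions.append({
        "u_matrices": U[:, :n_components],
        "singular_values": S[:n_components],
        "eigenvectors": Vh[:n_components].T,  # columns are the eigenvectors
    })

# Stack singular-value-weighted eigenvectors from every submission, then take a joint SVD.
stacked = torch.cat(
    [s["eigenvectors"].T * s["singular_values"][:, None] for s in submissions], dim=0
)
U, S, _ = torch.linalg.svd(stacked, full_matrices=False)
Z_common = (U @ torch.diag(S)).reshape(n_subs, n_components, -1)

# Project each submission's U matrix onto its block of the common components.
embeddings = {i: s["u_matrices"] @ Z_common[i] for i, s in enumerate(submissions)}
print(embeddings[0].shape)  # torch.Size([20, 12]): 20 states in the 12-dim common basis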
