Skip to content

Commit

Permalink
fix tests not passing and fix plots
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Sep 5, 2024
1 parent b6f899d commit 340620c
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/cryo_challenge/_commands/run_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main(args):
with open(args.config, "r") as file:
config = yaml.safe_load(file)

config = SVDConfig(**config).dict()
config = SVDConfig(**config).model_dump()

warnexists(config["output_params"]["output_file"])
mkbasedir(os.path.dirname(config["output_params"]["output_file"]))
Expand Down
64 changes: 52 additions & 12 deletions src/cryo_challenge/_svd/svd_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
compute_common_embedding,
project_to_gt_embedding,
)
from .svd_plots import plot_distance_matrix, plot_common_embedding, plot_gt_embedding
from .svd_plots import (
plot_distance_matrix,
plot_common_embedding,
plot_gt_embedding,
plot_common_eigenvectors,
)
from ..data._io.svd_io_utils import load_submissions_svd, load_gt_svd


Expand All @@ -33,25 +38,43 @@ def run_svd_with_ref(config: dict):
torch.save(results, config["output_params"]["output_file"])

if config["output_params"]["generate_plots"]:
outputs_fname_nopath_noext = os.path.basename(
config["output_params"]["output_file"]
)
outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")

os.makedirs(path_plots, exist_ok=True)

print("Plotting distance matrix")
plot_distance_matrix(
dist_mtx_results["dist_matrix"],
dist_mtx_results["labels"],
"SVD Distance Matrix",
save_path=os.path.join(outputs_path, "svd_distance_matrix.png"),
title="SVD Distance Matrix",
save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
)

print("Plotting common embedding")
plot_common_embedding(
submissions_data,
common_embedding_results,
"Common Embedding between submissions",
save_path=os.path.join(outputs_path, "common_embedding.png"),
title="Common Embedding between submissions",
save_path=os.path.join(path_plots, "common_embedding.png"),
)

print("Plotting gt embedding")
plot_gt_embedding(
submissions_data,
gt_embedding_results,
"",
save_path=os.path.join(outputs_path, "gt_embedding.png"),
title="",
save_path=os.path.join(path_plots, "gt_embedding.png"),
)

print("Plotting common eigenvectors")
plot_common_eigenvectors(
common_embedding_results["common_eigenvectors"],
title="Common Eigenvectors between submissions",
save_path=os.path.join(path_plots, "common_eigenvectors.png"),
)

return
Expand All @@ -62,11 +85,11 @@ def run_svd_noref(config: dict):

submissions_data = load_submissions_svd(config)
dist_mtx_results = compute_distance_matrix(submissions_data)
common_embeddings_results = compute_common_embedding(submissions_data)
common_embedding_results = compute_common_embedding(submissions_data)

results = {
"distance_matrix_results": dist_mtx_results,
"common_embedding_results": common_embeddings_results,
"common_embedding_results": common_embedding_results,
}

if config["output_params"]["save_svd_data"]:
Expand All @@ -75,18 +98,35 @@ def run_svd_noref(config: dict):
torch.save(results, config["output_params"]["output_file"])

if config["output_params"]["generate_plots"]:
outputs_fname_nopath_noext = os.path.basename(
config["output_params"]["output_file"]
)
outputs_fname_nopath_noext = os.path.splitext(outputs_fname_nopath_noext)[0]
path_plots = os.path.join(outputs_path, f"plots_{outputs_fname_nopath_noext}")
os.makedirs(path_plots, exist_ok=True)

print("Plotting distance matrix")

plot_distance_matrix(
dist_mtx_results["dist_matrix"],
dist_mtx_results["labels"],
"SVD Distance Matrix",
save_path=os.path.join(outputs_path, "svd_distance_matrix.png"),
save_path=os.path.join(path_plots, "svd_distance_matrix.png"),
)

print("Plotting common embedding")
plot_common_embedding(
submissions_data,
common_embeddings_results,
common_embedding_results,
"Common Embedding between submissions",
save_path=os.path.join(outputs_path, "common_embedding.png"),
save_path=os.path.join(path_plots, "common_embedding.png"),
)

print("Plotting common eigenvectors")
plot_common_eigenvectors(
common_embedding_results["common_eigenvectors"],
title="Common Eigenvectors between submissions",
save_path=os.path.join(path_plots, "common_eigenvectors.png"),
)

return
100 changes: 80 additions & 20 deletions src/cryo_challenge/_svd/svd_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
# "Rocky Road": {"color": "#4daf4a", "marker": "*"},
# }

COLORS = sns.color_palette("Set3", 12)
COLORS = [color for color in COLORS.as_hex()]

MARKERS = ["o", "v", "^", "<", ">", "D", "x", "*", "s", "p", "P", "*"]
MARKERS = ["o", "v", "^", "<", ">", "D", "X", "*", "s", "p", "P", "*", "h", "H"]
LABELS = [
"Ground Truth",
"Cookie Dough",
Expand All @@ -35,17 +32,24 @@
"Salted Caramel",
"Chocolate Chip",
"Rocky Road",
"Pina Colada",
"Bubble Gum",
]

assert len(MARKERS) >= len(LABELS)

COLORS = sns.color_palette("Set3", len(LABELS))
COLORS = [color for color in COLORS.as_hex()]

PLOT_SETUP = {}

for i in range(12):
for i in range(len(LABELS)):
PLOT_SETUP[LABELS[i]] = {
"color": COLORS[i],
"marker": MARKERS[i],
}

PLOT_SETUP["gt_left"] = {"color": "#e41a1c", "marker": "o"}
PLOT_SETUP["Coffee"] = {"color": "#e41a1c", "marker": "o"}
PLOT_SETUP["gt_right"] = {"color": "#377eb8", "marker": "v"}


Expand All @@ -58,8 +62,8 @@ def plot_distance_matrix(dist_matrix, labels, title="", save_path=None):

ax.set_title(title)
if save_path is not None:
plt.savefig(save_path)
plt.show()
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1)

return


Expand All @@ -77,7 +81,8 @@ def plot_common_embedding(
plot_setup = {}
for i, label in enumerate(labels):
for possible_label in PLOT_SETUP.keys():
if label in possible_label:
# print(label, possible_label)
if possible_label in label:
plot_setup[label] = PLOT_SETUP[possible_label]

for label in labels:
Expand Down Expand Up @@ -106,19 +111,17 @@ def plot_common_embedding(
n_rows = np.ceil(n_rows).astype(int)
n_cols = np.ceil(len(labels) / n_rows).astype(int)

print(n_cols, n_rows)
fig, ax = plt.subplots(
n_rows, n_cols, figsize=(n_cols * 5, n_rows * 3), sharex=True, sharey=True
)
if n_rows == 1 and n_cols == 1:
ax = np.array([ax])

print(all_embeddings.shape)
for i in range(len(labels)):
sns.kdeplot(
x=all_embeddings[:, pc1],
y=all_embeddings[:, pc2],
cmap="viridis",
cmap="gray",
fill=True,
cbar=False,
ax=ax.flatten()[i],
Expand All @@ -130,7 +133,7 @@ def plot_common_embedding(
sns.kdeplot(
x=all_embeddings[:, pc1],
y=all_embeddings[:, pc2],
cmap="viridis",
cmap="gray",
fill=True,
cbar=False,
ax=ax.flatten()[len(labels)],
Expand Down Expand Up @@ -163,8 +166,8 @@ def plot_common_embedding(
label=str(i + 1) + ". " + labels[i],
)

ax.flatten()[i].set_xticks([])
ax.flatten()[i].set_yticks([])
# ax.flatten()[i].set_xticks([])
# ax.flatten()[i].set_yticks([])

if i >= n_rows:
ax.flatten()[i].set_xlabel(f"Z{pc1 + 1}", fontsize=12)
Expand Down Expand Up @@ -215,7 +218,7 @@ def plot_common_embedding(
)

if save_path is not None:
plt.savefig(save_path)
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1)

return

Expand All @@ -237,19 +240,21 @@ def compute_gt_dist(z):
fig, ax = plt.subplots(
n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), sharex=True, sharey=True
)
if n_rows == 1 and n_cols == 1:
ax = np.array([ax])

plot_setup = {}
for i, label in enumerate(submissions_data.keys()):
for possible_label in PLOT_SETUP.keys():
if label in possible_label:
if possible_label in label:
plot_setup[label] = PLOT_SETUP[possible_label]

for label in submissions_data.keys():
if label not in plot_setup.keys():
raise ValueError(f"Label {label} not found in PLOT_SETUP")

low_gt = -227.927103122416
high_gt = 214.014930744738
low_gt = -231.62100638454024
high_gt = 243.32448171011487
Z = np.linspace(low_gt, high_gt, gt_embedding_results["gt_embedding"].shape[0])
x_axis = np.linspace(
torch.min(gt_embedding_results["gt_embedding"][:, 0]),
Expand Down Expand Up @@ -327,6 +332,61 @@ def compute_gt_dist(z):
)

if save_path is not None:
plt.savefig(save_path)
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1)

return


def plot_common_eigenvectors(
    common_eigenvectors, n_eig_to_plot=None, title="", save_path=None
):
    """Plot sign maps of the leading common eigenvectors on a grid of panels.

    Each eigenvector is reshaped to a cube, reduced to its sign (+1 / -1, with
    entries whose magnitude is below 1e-3 set to 0) and summed along the first
    axis, which yields one 2D image per eigenvector.

    Parameters
    ----------
    common_eigenvectors
        Indexable collection of flattened eigenvectors (torch tensors). Each
        entry is reshaped to ``(box_size, box_size, box_size)`` where
        ``box_size`` is the rounded cube root of the entry's last dimension.
    n_eig_to_plot : int, optional
        Maximum number of eigenvectors to draw. Defaults to 10 and is always
        capped at ``len(common_eigenvectors)``.
    title : str
        Figure-level title.
    save_path : str, optional
        If given, the figure is written to this path.

    Raises
    ------
    ValueError
        If there is no eigenvector to plot.
    """
    # Honour the caller's request; previously this argument was overwritten.
    requested = 10 if n_eig_to_plot is None else n_eig_to_plot
    n_eig_to_plot = min(requested, len(common_eigenvectors))
    if n_eig_to_plot < 1:
        raise ValueError("At least one eigenvector is required for plotting")

    n_cols = 5
    n_rows = int(np.ceil(n_eig_to_plot / n_cols))

    fig, ax = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 5))
    panels = ax.flatten()

    box_size = int(round((common_eigenvectors[0].shape[-1]) ** (1 / 3)))
    small = 1e-3  # magnitudes strictly below this are treated as zero
    last_image = None
    for i in range(n_eig_to_plot):
        eigvol = common_eigenvectors[i].reshape(box_size, box_size, box_size)

        # Sign map: +1 where clearly positive, -1 where clearly negative.
        eigvol_for_img = torch.zeros_like(eigvol)
        eigvol_for_img[eigvol >= small] = 1.0
        eigvol_for_img[eigvol <= -small] = -1.0

        last_image = panels[i].imshow(
            eigvol_for_img.sum(0), cmap="coolwarm", label=f"Eigenvector {i}"
        )
        panels[i].set_title(f"Eigenvector {i}")
        panels[i].axis("off")

    # Hide the panels of the grid that received no eigenvector.
    for unused_panel in panels[n_eig_to_plot:]:
        unused_panel.axis("off")

    plt.subplots_adjust(wspace=0.0)

    # add a colorbar for the whole figure, reusing the last drawn image
    # instead of drawing a duplicate image on top of the last panel
    fig.colorbar(
        last_image,
        ax=ax,
        orientation="horizontal",
        label="Eigenvector value (neg or pos)",
    )

    fig.suptitle(title, fontsize=16)

    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1)

    return
24 changes: 23 additions & 1 deletion src/cryo_challenge/_svd/svd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ def sort_matrix_using_gt(dist_matrix: torch.Tensor, labels: np.ndarray):
return dist_matrix, labels


def sort_matrix(dist_matrix, labels):
    """Reorder a square distance matrix and its labels in two passes.

    Pass 1 orders the entries by decreasing column sum (``dim=0``; identical
    to the row sum when the matrix is symmetric). Pass 2 reorders the result
    by decreasing value of its first column. Rows and columns are always
    permuted together, and the inputs are left unmodified.

    Parameters
    ----------
    dist_matrix : torch.Tensor
        Square matrix of pairwise distances.
    labels : array-like
        One label per row/column. Lists and tuples are accepted and converted
        to a NumPy array.

    Returns
    -------
    tuple
        ``(sorted_dist_matrix, sorted_labels)`` with ``sorted_labels`` a
        NumPy array.
    """
    # Indexing below builds new objects, so the caller's data is untouched.
    dist_matrix = dist_matrix.clone()
    labels = np.asarray(labels).copy()

    def _apply_order(matrix, names, order):
        # Permute columns and rows with the same index; `.cpu()` keeps the
        # label indexing valid when the matrix lives on another device.
        return matrix[:, order][order], names[order.cpu().numpy()]

    # Pass 1: decreasing column sums
    by_total = torch.argsort(torch.sum(dist_matrix, dim=0), descending=True)
    dist_matrix, labels = _apply_order(dist_matrix, labels, by_total)

    # Pass 2: decreasing values of the first column of the reordered matrix
    by_first = torch.argsort(dist_matrix[:, 0], descending=True)
    dist_matrix, labels = _apply_order(dist_matrix, labels, by_first)

    return dist_matrix, labels


def compute_distance_matrix(submissions_data, gt_data=None):
n_subs = len(list(submissions_data.keys()))
labels = list(submissions_data.keys())
Expand Down Expand Up @@ -64,7 +82,10 @@ def compute_distance_matrix(submissions_data, gt_data=None):

dist_matrix, labels = sort_matrix_using_gt(dist_matrix, labels)

labels = np.array(labels)
else:
labels = np.array(labels)
dist_matrix, labels = sort_matrix(dist_matrix, labels)

results = {"dist_matrix": dist_matrix, "labels": labels}
return results

Expand Down Expand Up @@ -101,6 +122,7 @@ def compute_common_embedding(submissions_data, gt_data=None):
results = {
"common_embedding": embeddings,
"singular_values": S,
"common_eigenvectors": V,
}

if gt_data is not None:
Expand Down
5 changes: 4 additions & 1 deletion src/cryo_challenge/data/_io/svd_io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def load_submissions_svd(
label = submission["id"]
populations = submission["populations"]

if not isinstance(populations, torch.Tensor):
populations = torch.tensor(populations)

volumes = submission["volumes"]
if config["normalize_params"]["mask_path"] is not None:
volumes = volumes * mask
Expand Down Expand Up @@ -101,7 +104,7 @@ def load_submissions_svd(
)

submissions_data[label] = {
"populations": torch.tensor(populations / populations.sum()),
"populations": populations / populations.sum(),
"u_matrices": u_matrices.clone(),
"singular_values": singular_values.clone(),
"eigenvectors": eigenvectors.clone(),
Expand Down

0 comments on commit 340620c

Please sign in to comment.