diff --git a/src/cryo_challenge/_svd/svd_plots.py b/src/cryo_challenge/_svd/svd_plots.py index e4ec24e..1642204 100644 --- a/src/cryo_challenge/_svd/svd_plots.py +++ b/src/cryo_challenge/_svd/svd_plots.py @@ -475,7 +475,9 @@ def compute_gt_dist(z): populations_ref = submissions_data[label_ref]["populations"] embedding_ref = gt_embedding_results["submission_embedding"][label_ref] - n_cols = 4 + labels.pop(labels.index(label_ref)) + + n_cols = 3 if n_cols > len(labels): n_cols = len(labels) @@ -491,15 +493,15 @@ def compute_gt_dist(z): for i in range(len(labels)): label = labels[i] embedding = gt_embedding_results["submission_embedding"][label] - ax.flatten()[i].text( - 0.05, - 0.95, - str(i + 1), - fontsize=12, - transform=ax.flatten()[i].transAxes, - verticalalignment="top", - bbox=dict(facecolor="white", alpha=0.5), - ) + # ax.flatten()[i].text( + # 0.05, + # 0.95, + # str(i + 1), + # fontsize=12, + # transform=ax.flatten()[i].transAxes, + # verticalalignment="top", + # bbox=dict(facecolor="white", alpha=0.5), + # ) # ax.flatten()[i].bar( # edges[:-1],