diff --git a/config_files/config_svd.yaml b/config_files/config_svd.yaml
index 71e7618..d5e2da0 100644
--- a/config_files/config_svd.yaml
+++ b/config_files/config_svd.yaml
@@ -1,19 +1,21 @@
-path_to_volumes: /path/to/volumes
-box_size_ds: 32
-submission_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
-experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref"
+path_to_submissions: path/to/preprocessed/submissions/ # where all the submission_i.pt files are
+#excluded_submissions: # you can exclude some submissions by filename, default = []
+#  - "submission_0.pt"
+#  - "submission_1.pt"
+voxel_size: 1.0 # voxel size of the input maps (will probably be removed soon)
-power_spectrum_normalization:
-  ref_vol_key: "FLAVOR" # which submission should be used
-  ref_vol_index: 0 # which volume of that submission should be used
+dtype: float32 # optional, default = float32
+svd_max_rank: 5 # optional, default = full rank svd
+normalize_params: # optional, if not given there will be no normalization
+  mask_path: path/to/mask.mrc # default = None, no masking applied
+  bfactor: 170 # default = None, no bfactor applied
+  box_size_ds: 16 # default = None, no downsampling applied
-# optional unless experiment_mode is "all_vs_ref"
-path_to_reference: /path/to/reference/volumes.pt
-dtype: "float32" # options are "float32", "float64"
-output_options:
-  # path to file will be created if it does not exist
-  output_file: /path/to/output_file.pt
-  # whether or not to save the processed volumes (downsampled, normalized, etc.)
-  save_volumes: False
-  # whether or not to save the SVD matrices (U, S, V)
-  save_svd_matrices: False
+gt_params: # optional, if provided there will be extra results
+  gt_vols_file: path/to/gt_volumes.npy # must be .npy; loaded with mmap to keep memory use low
+  skip_vols: 1 # default = 1, no volumes skipped. Equivalent to volumes[::skip_vols]
+
+output_params:
+  output_file: path/to/output_file.pt # where the results will be saved
+  save_svd_data: True # optional, default = False
+  generate_plots: True # optional, default = False
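Under the new schema only path_to_submissions, voxel_size, and output_params.output_file are required; everything else falls back to the defaults noted in the comments above. A minimal sketch of this, using the SVDConfig validator introduced later in this diff (it assumes you run from the repository root so the test-data paths resolve):

    import yaml
    from cryo_challenge.data._validation.config_validators import SVDConfig

    minimal = yaml.safe_load(
        """
        path_to_submissions: tests/data/dataset_2_submissions/
        voxel_size: 1.0
        output_params:
          output_file: tests/results/svd/svd_result.pt
        """
    )
    config = SVDConfig(**minimal).dict()
    assert config["svd_max_rank"] is None        # default: full-rank SVD
    assert config["excluded_submissions"] == []  # default: no submissions excluded
    assert config["gt_params"] is None           # default: no ground-truth results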
diff --git a/src/cryo_challenge/_commands/run_svd.py b/src/cryo_challenge/_commands/run_svd.py
index a6d06be..3a73204 100644
--- a/src/cryo_challenge/_commands/run_svd.py
+++ b/src/cryo_challenge/_commands/run_svd.py
@@ -6,8 +6,8 @@
 import os
 import yaml
 
-from .._svd.svd_pipeline import run_all_vs_all_pipeline, run_all_vs_ref_pipeline
-from ..data._validation.config_validators import validate_config_svd
+from .._svd.svd_pipeline import run_svd_noref, run_svd_with_ref
+from ..data._validation.config_validators import SVDConfig
 
 
 def add_args(parser):
@@ -35,15 +35,21 @@ def main(args):
     with open(args.config, "r") as file:
         config = yaml.safe_load(file)
 
-    validate_config_svd(config)
-    warnexists(config["output_options"]["output_file"])
-    mkbasedir(os.path.dirname(config["output_options"]["output_file"]))
+    config = SVDConfig(**config).dict()
 
-    if config["experiment_mode"] == "all_vs_all":
-        run_all_vs_all_pipeline(config)
+    warnexists(config["output_params"]["output_file"])
+    mkbasedir(os.path.dirname(config["output_params"]["output_file"]))
 
-    elif config["experiment_mode"] == "all_vs_ref":
-        run_all_vs_ref_pipeline(config)
+    output_path = os.path.dirname(config["output_params"]["output_file"])
+
+    with open(os.path.join(output_path, "config.yaml"), "w") as file:
+        yaml.dump(config, file)
+
+    if config["gt_params"] is None:
+        run_svd_noref(config)
+
+    else:
+        run_svd_with_ref(config)
 
     return
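The pydantic model replaces validate_config_svd, so a malformed config now fails at load time with a ValidationError instead of an assert deep in the pipeline. A sketch of the failure mode (field values are dummies; the error text comes from the check_dtype validator added below):

    from pydantic import ValidationError
    from cryo_challenge.data._validation.config_validators import SVDConfig

    raw = {
        "path_to_submissions": "tests/data/dataset_2_submissions/",
        "voxel_size": 1.0,
        "dtype": "float16",  # only "float32" and "float64" are accepted
        "output_params": {"output_file": "tests/results/svd/svd_result.pt"},
    }

    try:
        SVDConfig(**raw)
    except ValidationError as err:
        print(err)  # the dtype validator reports: Invalid dtype float16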
diff --git a/src/cryo_challenge/_svd/svd_pipeline.py b/src/cryo_challenge/_svd/svd_pipeline.py
index 50a04d8..7c9d840 100644
--- a/src/cryo_challenge/_svd/svd_pipeline.py
+++ b/src/cryo_challenge/_svd/svd_pipeline.py
@@ -1,245 +1,94 @@
 import torch
-from typing import Tuple
-import yaml
-import argparse
-
-from .svd_utils import get_vols_svd, project_vols_to_svd
-from ..data._io.svd_io_utils import (
-    load_volumes,
-    load_ref_vols,
-    remove_mean_volumes,
+import os
+
+from .svd_utils import (
+    compute_distance_matrix,
+    compute_common_embedding,
+    project_to_gt_embedding,
 )
-from ..data._validation.config_validators import validate_config_svd
-
-
-def run_svd_with_ref(
-    volumes: torch.tensor, ref_volumes: torch.tensor
-) -> Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
-    """
-    Compute the singular value decomposition of the reference volumes and project the input volumes onto the right singular vectors of the reference volumes.
-
-    Parameters
-    ----------
-    volumes: torch.tensor
-        Tensor of shape (n_volumes, n_x, n_y, n_z) containing the volumes to be projected.
-    ref_volumes: torch.tensor
-        Tensor of shape (n_volumes_ref, n_x, n_y, n_z) containing the reference volumes.
-
-    Returns
-    -------
-    U: torch.tensor
-        Left singular vectors of the reference volumes.
-    S: torch.tensor
-        Singular values of the reference volumes.
-    V: torch.tensor
-        Right singular vectors of the reference volumes.
-    coeffs: torch.tensor
-        Coefficients of the input volumes projected onto the right singular vectors of the reference volumes.
-
-    Examples
-    --------
-    >>> volumes = torch.randn(10, 32, 32, 32)
-    >>> ref_volumes = torch.randn(10, 32, 32, 32)
-    >>> U, S, V, coeffs = run_svd_with_ref(volumes, ref_volumes)
-    """  # noqa: E501
-
-    assert volumes.ndim == 4, "Input volumes must have shape (n_volumes, n_x, n_y, n_z)"
-    assert volumes.shape[0] > 0, "Input volumes must have at least one volume"
-    assert (
-        ref_volumes.ndim == 4
-    ), "Reference volumes must have shape (n_volumes, n_x, n_y, n_z)"
-    assert ref_volumes.shape[0] > 0, "Reference volumes must have at least one volume"
-    assert (
-        volumes.shape[1:] == ref_volumes.shape[1:]
-    ), "Input volumes and reference volumes must have the same shape"
-
-    U, S, V = get_vols_svd(ref_volumes)
-    coeffs = project_vols_to_svd(volumes, V)
-    coeffs_ref = U @ torch.diag(S)
-
-    return U, S, V, coeffs, coeffs_ref
-
-
-def run_svd_all_vs_all(volumes: torch.tensor):
-    """
-    Compute the singular value decomposition of the input volumes.
-
-    Parameters
-    ----------
-    volumes: torch.tensor
-        Tensor of shape (n_volumes, n_x, n_y, n_z) containing the volumes to be decomposed.
-
-    Returns
-    -------
-    U: torch.tensor
-        Left singular vectors of the input volumes.
-    S: torch.tensor
-        Singular values of the input volumes.
-    V: torch.tensor
-        Right singular vectors of the input volumes.
-    coeffs: torch.tensor
-        Coefficients of the input volumes projected onto the right singular vectors.
-
-    Examples
-    --------
-    >>> volumes = torch.randn(10, 32, 32, 32)
-    >>> U, S, V, coeffs = run_svd_all_vs_all(volumes)
-    """  # noqa: E501
-    U, S, V = get_vols_svd(volumes)
-    coeffs = U @ torch.diag(S)
-    return U, S, V, coeffs
-
-
-def run_all_vs_all_pipeline(config: dict):
-    """
-    Run the all-vs-all SVD pipeline. Load the volumes, compute the SVD, and save the results.
-
-    Parameters
-    ----------
-    config: dict
-        Dictionary containing the configuration options for the pipeline.
-
-    The results are saved in a dictionary with the following
-    keys:
-    - coeffs: Coefficients of the input volumes projected onto the right singular vectors.
-    - metadata: Dictionary containing the populations and indices of each submission (see Tutorial for details).
-    - vols_per_submission: Dictionary containing the number of volumes per submission.
-
-    if save_volumes set to True
-    - volumes: Tensor of shape (n_volumes, n_x, n_y, n_z) containing the volumes.
-    - mean_volumes: Tensor of shape (n_submissions, n_x, n_y, n_z) containing the mean volume of each submission.
-    * Note: the volumes will be downsampled, normalized and mean-removed
-
-    if save_svd_matrices set to True
-    - U: Left singular vectors of the input volumes.
-    - S: Singular values of the input volumes.
-    - V: Right singular vectors of the input volumes.
-
-    """  # noqa: E501
-
-    dtype = torch.float32 if config["dtype"] == "float32" else torch.float64
-    volumes, metadata = load_volumes(
-        box_size_ds=config["box_size_ds"],
-        submission_list=config["submission_list"],
-        path_to_submissions=config["path_to_volumes"],
-        dtype=dtype,
-    )
-
-    volumes, mean_volumes = remove_mean_volumes(volumes, metadata)
-
-    U, S, V, coeffs = run_svd_all_vs_all(volumes=volumes)
-
-    output_dict = {
-        "coeffs": coeffs,
-        "metadata": metadata,
-        "config": config,
-        "sing_vals": S,
-    }
+from .svd_plots import plot_distance_matrix, plot_common_embedding, plot_gt_embedding
+from ..data._io.svd_io_utils import load_submissions_svd, load_gt_svd
+
+
+def run_svd_with_ref(config: dict):
+    outputs_path = os.path.dirname(config["output_params"]["output_file"])
+
+    submissions_data = load_submissions_svd(config)
+    gt_data = load_gt_svd(config)
 
-    if config["output_options"]["save_volumes"]:
-        output_dict["volumes"] = volumes
-        output_dict["mean_volumes"] = mean_volumes
-
-    if config["output_options"]["save_svd_matrices"]:
-        output_dict["U"] = U
-        output_dict["V"] = V
-        output_dict["S"] = S
-
-    torch.save(output_dict, config["output_options"]["output_file"])
-
-    return output_dict
-
-
-def run_all_vs_ref_pipeline(config: dict):
-    """
-    Run the all-vs-ref SVD pipeline. Load the volumes, compute the SVD, and save the results.
-
-    Parameters
-    ----------
-    config: dict
-        Dictionary containing the configuration options for the pipeline.
-
-    The results are saved in a dictionary with the following
-    keys:
-    - coeffs: Coefficients of the input volumes projected onto the right singular vectors.
-    - populations: Dictionary containing the populations of each submission.
-    - vols_per_submission: Dictionary containing the number of volumes per submission.
-
-    if save_volumes set to True
-    - volumes: Tensor of shape (n_volumes, n_x, n_y, n_z) containing the volumes.
-    - mean_volumes: Tensor of shape (n_submissions, n_x, n_y, n_z) containing the mean volume of each submission.
-    - ref_volumes: Tensor of shape (n_volumes_ref, n_x, n_y, n_z) containing the reference volumes.
-    - mean_ref_volumes: Tensor of shape (n_x, n_y, n_z) containing the mean reference volume.
-    * Note: volumes and ref_volumes will be downsampled, normalized and mean-removed
-
-    if save_svd_matrices set to True
-    - U: Left singular vectors of the reference volumes.
-    - S: Singular values of the reference volumes.
-    - V: Right singular vectors of the reference volumes.
-
-    """  # noqa: E501
-
-    dtype = torch.float32 if config["dtype"] == "float32" else torch.float64
-
-    ref_volumes = load_ref_vols(
-        box_size_ds=config["box_size_ds"],
-        path_to_volumes=config["path_to_reference"],
-        dtype=dtype,
-    )
-
-    volumes, metadata = load_volumes(
-        box_size_ds=config["box_size_ds"],
-        submission_list=config["submission_list"],
-        path_to_submissions=config["path_to_volumes"],
-        dtype=dtype,
-    )
-
-    # Remove mean volumes
-    volumes, mean_volumes = remove_mean_volumes(volumes, metadata)
-    ref_volumes, mean_volume = remove_mean_volumes(ref_volumes)
-
-    U, S, V, coeffs, coeffs_ref = run_svd_with_ref(
-        volumes=volumes, ref_volumes=ref_volumes
-    )
-
-    output_dict = {
-        "coeffs": coeffs,
-        "coeffs_ref": coeffs_ref,
-        "metadata": metadata,
-        "config": config,
-        "sing_vals": S,
+    dist_mtx_results = compute_distance_matrix(submissions_data, gt_data)
+    common_embedding_results = compute_common_embedding(submissions_data, gt_data)
+    gt_embedding_results = project_to_gt_embedding(submissions_data, gt_data)
+
+    results = {
+        "distance_matrix_results": dist_mtx_results,
+        "common_embedding_results": common_embedding_results,
+        "gt_embedding_results": gt_embedding_results,
     }
 
-    if config["output_options"]["save_volumes"]:
-        output_dict["volumes"] = volumes
-        output_dict["mean_volumes"] = mean_volumes
-        output_dict["ref_volumes"] = ref_volumes
-        output_dict["mean_ref_volume"] = mean_volume
+    if config["output_params"]["save_svd_data"]:
+        results["submissions_data"] = submissions_data
+        results["gt_data"] = gt_data
+
+    torch.save(results, config["output_params"]["output_file"])
+
+    if config["output_params"]["generate_plots"]:
+        plot_distance_matrix(
+            dist_mtx_results["dist_matrix"],
+            dist_mtx_results["labels"],
+            "SVD Distance Matrix",
+            save_path=os.path.join(outputs_path, "svd_distance_matrix.png"),
+        )
+
+        plot_common_embedding(
+            submissions_data,
+            common_embedding_results,
+            "Common Embedding between submissions",
+            save_path=os.path.join(outputs_path, "common_embedding.png"),
+        )
+
+        plot_gt_embedding(
+            submissions_data,
+            gt_embedding_results,
+            "",
+            save_path=os.path.join(outputs_path, "gt_embedding.png"),
+        )
 
-    if config["output_options"]["save_svd_matrices"]:
-        output_dict["U"] = U
-        output_dict["S"] = S
-        output_dict["V"] = V
+    return
 
-    torch.save(output_dict, config["output_options"]["output_file"])
 
-    return output_dict
+def run_svd_noref(config: dict):
+    outputs_path = os.path.dirname(config["output_params"]["output_file"])
+    submissions_data = load_submissions_svd(config)
 
+    dist_mtx_results = compute_distance_matrix(submissions_data)
+    common_embedding_results = compute_common_embedding(submissions_data)
 
-def run_svd_pipeline():
-    parser = argparse.ArgumentParser(description="Run SVD on volumes")
-    parser.add_argument(
-        "--config", type=str, default=None, help="Path to the config (yaml) file"
-    )
-    args = parser.parse_args()
+    results = {
+        "distance_matrix_results": dist_mtx_results,
+        "common_embedding_results": common_embedding_results,
+    }
 
-    with open(args.config, "r") as file:
-        config = yaml.safe_load(file)
+    if config["output_params"]["save_svd_data"]:
+        results["submissions_data"] = submissions_data
 
-    validate_config_svd(config)
+    torch.save(results, config["output_params"]["output_file"])
 
-    return
+    if config["output_params"]["generate_plots"]:
+        plot_distance_matrix(
+            dist_mtx_results["dist_matrix"],
+            dist_mtx_results["labels"],
+            "SVD Distance Matrix",
+            save_path=os.path.join(outputs_path, "svd_distance_matrix.png"),
+        )
 
+        plot_common_embedding(
+            submissions_data,
+            common_embedding_results,
+            "Common Embedding between submissions",
+            save_path=os.path.join(outputs_path, "common_embedding.png"),
+        )
 
-if __name__ == "__main__":
-    run_svd_pipeline()
+    return
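Both pipelines persist a single dictionary via torch.save, so downstream analysis is a plain torch.load away. A sketch of reading the with-reference output (key names follow the results dicts assembled above; the path is a placeholder):

    import torch

    results = torch.load("path/to/output_file.pt")

    dist = results["distance_matrix_results"]
    print(dist["labels"])             # submission labels, GT-sorted
    print(dist["dist_matrix"].shape)  # (n_subs + 1, n_subs + 1) when GT is included

    common = results["common_embedding_results"]["common_embedding"]
    for label, coords in common.items():
        print(label, coords.shape)    # (n_volumes, common-space dimension) per submission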
diff --git a/src/cryo_challenge/_svd/svd_plots.py b/src/cryo_challenge/_svd/svd_plots.py
new file mode 100644
index 0000000..47f44f1
--- /dev/null
+++ b/src/cryo_challenge/_svd/svd_plots.py
@@ -0,0 +1,229 @@
+import matplotlib.pyplot as plt
+import seaborn as sns
+import torch
+import numpy as np
+
+
+def plot_distance_matrix(dist_matrix, labels, title="", save_path=None):
+    fig, ax = plt.subplots()
+    cax = ax.matshow(dist_matrix, cmap="viridis")
+    fig.colorbar(cax)
+    ax.set_xticks(np.arange(len(labels)), labels, rotation=90)
+    ax.set_yticks(np.arange(len(labels)), labels)
+
+    ax.set_title(title)
+    if save_path is not None:
+        plt.savefig(save_path)
+    plt.show()
+    return
+
+
+def plot_common_embedding(
+    submissions_data, embedding_results, title="", pcs=(0, 1), save_path=None
+):
+    all_embeddings = []
+    labels = []
+    pc1, pc2 = pcs
+
+    for label, embedding in embedding_results["common_embedding"].items():
+        all_embeddings.append(embedding)
+        labels.append(label)
+
+    all_embeddings = torch.cat(all_embeddings, dim=0)
+
+    weights = []
+    for i in range(len(labels)):
+        weights += submissions_data[labels[i]]["populations"].numpy().tolist()
+
+    weights = torch.tensor(weights)
+    weights = weights / weights.sum()
+
+    if "gt_embedding" in embedding_results:
+        n_cols = min(3, len(labels) + 1)
+        n_rows = max(-(-(len(labels) + 1) // n_cols), 1)
+
+    else:
+        n_cols = min(3, len(labels))
+        n_rows = max(-(-len(labels) // n_cols), 1)
+
+    fig, ax = plt.subplots(
+        n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), sharex=True, sharey=True
+    )
+    if n_rows == 1 and n_cols == 1:
+        ax = np.array([ax])
+
+    for i in range(len(labels)):
+        sns.kdeplot(
+            x=all_embeddings[:, pc1],
+            y=all_embeddings[:, pc2],
+            cmap="viridis",
+            fill=True,
+            cbar=False,
+            ax=ax.flatten()[i],
+            weights=weights,
+        )
+
+    if "gt_embedding" in embedding_results:
+        sns.kdeplot(
+            x=all_embeddings[:, pc1],
+            y=all_embeddings[:, pc2],
+            cmap="viridis",
+            fill=True,
+            cbar=False,
+            ax=ax.flatten()[len(labels)],
+            weights=weights,
+        )
+
+    for i in range(len(labels)):
+        pops = submissions_data[labels[i]]["populations"].numpy()
+        pops = pops / pops.sum()
+
+        ax.flatten()[i].scatter(
+            x=embedding_results["common_embedding"][labels[i]][:, pc1],
+            y=embedding_results["common_embedding"][labels[i]][:, pc2],
+            color="red",
+            s=pops / pops.max() * 200,
+            marker="o",
+            linewidth=0.3,
+            edgecolor="white",
+            label=labels[i],
+        )
+
+        ax.flatten()[i].set_xticks([])
+        ax.flatten()[i].set_yticks([])
+
+        if i >= n_rows:
+            ax.flatten()[i].set_xlabel(f"Z{pc1 + 1}", fontsize=12)
+        if i % n_cols == 0:
+            ax.flatten()[i].set_ylabel(f"Z{pc2 + 1}", fontsize=12)
+
+        i_max = i
+
+    if "gt_embedding" in embedding_results:
+        i_max += 1
+
+        ax.flatten()[i_max].scatter(
+            x=embedding_results["gt_embedding"][:, pc1],
+            y=embedding_results["gt_embedding"][:, pc2],
+            color="red",
+            s=100,
+            marker="o",
+            linewidth=0.3,
+            edgecolor="white",
+            label="Ground Truth",
+        )
+
+        ax.flatten()[i_max].set_xlabel(f"Z{pc1 + 1}", fontsize=12)
+        ax.flatten()[i_max].set_ylabel(f"Z{pc2 + 1}", fontsize=12)
+        ax.flatten()[i_max].set_xticks([])
+        ax.flatten()[i_max].set_yticks([])
+
+    if i_max < n_cols * n_rows:
+        for j in range(i_max + 1, n_cols * n_rows):
+            ax.flatten()[j].axis("off")
+
+    # adjust horizontal space
+    plt.subplots_adjust(wspace=0.0, hspace=0.0)
+
+    fig.suptitle(title, fontsize=16)
+    lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
+    lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
+    fig.legend(
+        lines, labels, loc="center right", fontsize=12, bbox_to_anchor=(1.071, 0.5)
+    )
+
+    if save_path is not None:
+        plt.savefig(save_path)
+
+    return
+
+
+def plot_gt_embedding(submissions_data, gt_embedding_results, title="", save_path=None):
+    def gauss_pdf(x, mu, var):
+        return 1 / np.sqrt(2 * np.pi * var) * np.exp(-0.5 * (x - mu) ** 2 / var)
+
+    def compute_gt_dist(z):
+        gauss1 = gauss_pdf(z, 150, 750)
+        gauss2 = 0.5 * gauss_pdf(z, 0, 500)
+        gauss3 = gauss_pdf(z, -150, 750)
+        return gauss1 + gauss2 + gauss3
+
+    n_cols = 3
+    n_rows = len(list(submissions_data.keys())) // n_cols + 1
+
+    fig, ax = plt.subplots(
+        n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), sharex=True, sharey=True
+    )
+
+    low_gt = -227.927103122416
+    high_gt = 214.014930744738
+    Z = np.linspace(low_gt, high_gt, gt_embedding_results["gt_embedding"].shape[0])
+    x_axis = np.linspace(
+        torch.min(gt_embedding_results["gt_embedding"][:, 0]),
+        torch.max(gt_embedding_results["gt_embedding"][:, 0]),
+        gt_embedding_results["gt_embedding"].shape[0],
+    )
+
+    gt_dist = compute_gt_dist(Z)
+    gt_dist /= np.max(gt_dist)
+
+    frq, edges = np.histogram(gt_embedding_results["gt_embedding"][:, 0], bins=20)
+
+    i = 0
+    for label, embedding in gt_embedding_results["submission_embedding"].items():
+        ax.flatten()[i].bar(
+            edges[:-1],
+            frq / frq.max(),
+            width=np.diff(edges),
+            # label="Ground Truth",
+            alpha=0.8,
+            color="#a1c9f4",
+        )
+
+        ax.flatten()[i].plot(x_axis, gt_dist)  # , label="True Distribution")
+
+        populations = submissions_data[label]["populations"]
+        ax.flatten()[i].scatter(
+            x=embedding[:, 0],
+            y=populations / populations.max(),
+            color="red",
+            marker="o",
+            s=60,
+            linewidth=0.3,
+            edgecolor="white",
+            label=label,
+        )
+
+        # set x label only for the last row
+        if i >= n_rows:
+            ax.flatten()[i].set_xlabel("SVD 1", fontsize=12)
+        # set y label only for the first column
+        if i % n_cols == 0:
+            ax.flatten()[i].set_ylabel("Scaled probability", fontsize=12)
+
+        # ax.flatten()[i].legend(loc="upper left", fontsize=12)
+        ax.flatten()[i].set_ylim(0.0, 1.1)
+        ax.flatten()[i].set_xlim(x_axis[0] * 1.3, x_axis[-1] * 1.3)
+        # set ticks to be maximum 5 ticks
+        ax.flatten()[i].set_yticks(np.arange(0, 1.25, 0.25))
+        ax.flatten()[i].set_xticks([])
+
+        plt.subplots_adjust(wspace=0.0, hspace=0.0)
+
+        i += 1
+
+    if i < n_cols * n_rows:
+        for j in range(i + 1, n_cols * n_rows):
+            ax.flatten()[j].axis("off")
+
+    fig.suptitle(title, fontsize=16)
+    lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
+    lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
+    fig.legend(
+        lines, labels, loc="center right", fontsize=12, bbox_to_anchor=(1.071, 0.5)
+    )
+
+    if save_path is not None:
+        plt.savefig(save_path)
+
+    return
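A usage sketch for the distance-matrix plot with dummy inputs (the import path mirrors the new module; with save_path=None the figure is shown but not written to disk):

    import numpy as np
    import torch
    from cryo_challenge._svd.svd_plots import plot_distance_matrix

    similarity = torch.rand(4, 4)
    similarity = 0.5 * (similarity + similarity.T)  # symmetric, like the real matrix
    labels = np.array(["submission_a", "submission_b", "submission_c", "Ground Truth"])

    plot_distance_matrix(similarity, labels, title="demo", save_path=None)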
diff --git a/src/cryo_challenge/_svd/svd_utils.py b/src/cryo_challenge/_svd/svd_utils.py
index b040c3a..4cdaa48 100644
--- a/src/cryo_challenge/_svd/svd_utils.py
+++ b/src/cryo_challenge/_svd/svd_utils.py
@@ -1,73 +1,130 @@
 import torch
-from typing import Tuple
-
-
-def get_vols_svd(
-    volumes: torch.tensor,
-) -> Tuple[torch.tensor, torch.tensor, torch.tensor]:
-    """
-    Compute the singular value decomposition of the input volumes. The volumes are flattened so that each volume is a row in the input matrix.
-
-    Parameters
-    ----------
-    volumes: torch.tensor
-        Tensor of shape (n_volumes, n_x, n_y, n_z) containing the volumes to be decomposed.
-
-    Returns
-    -------
-    U: torch.tensor
-        Left singular vectors of the input volumes.
-    S: torch.tensor
-        Singular values of the input volumes.
-    V: torch.tensor
-        Right singular vectors of the input volumes.
-
-    Examples
-    --------
-    >>> volumes = torch.randn(10, 32, 32, 32)
-    >>> U, S, V = get_vols_svd(volumes)
-    """  # noqa: E501
-    assert volumes.ndim == 4, "Input volumes must have shape (n_volumes, n_x, n_y, n_z)"
-    assert volumes.shape[0] > 0, "Input volumes must have at least one volume"
-
-    U, S, V = torch.svd_lowrank(volumes.reshape(volumes.shape[0], -1), q=40)
-    return U, S, V
-
-
-def project_vols_to_svd(
-    volumes: torch.tensor, V_reference: torch.tensor
-) -> torch.tensor:
-    """
-    Project the input volumes onto the right singular vectors.
-
-    Parameters
-    ----------
-    volumes: torch.tensor
-        Tensor of shape (n_volumes, n_x, n_y, n_z) containing the volumes to be projected.
-    V_reference: torch.tensor
-        Right singular vectors of the reference volumes.
-
-    Returns
-    -------
-    coeffs: torch.tensor
-        Coefficients of the input volumes projected onto the right singular vectors.
-
-    Examples
-    --------
-    >>> volumes1 = torch.randn(10, 32, 32, 32)
-    >>> volumes2 = torch.randn(10, 32, 32, 32)
-    >>> U, S, V = get_vols_svd(volumes1)
-    >>> coeffs = project_vols_to_svd(volumes2, V)
-    """  # noqa: E501
-    assert volumes.ndim == 4, "Input volumes must have shape (n_volumes, n_x, n_y, n_z)"
-    assert volumes.shape[0] > 0, "Input volumes must have at least one volume"
-    assert (
-        V_reference.ndim == 2
-    ), "Right singular vectors must have shape (n_features, n_components)"
-    assert (
-        volumes.shape[1] * volumes.shape[2] * volumes.shape[3] == V_reference.shape[1]
-    ), "Number of features in volumes must match number of features in right singular vectors"  # noqa: E501
-
-    coeffs = volumes.reshape(volumes.shape[0], -1) @ V_reference.T
-
-    return coeffs
+import numpy as np
+
+
+### Compare subspaces for each submission ###
+
+
+def captured_variance(V, U, S):
+    US = U @ torch.diag(S)
+    return torch.norm(torch.adjoint(V) @ US) / torch.norm(torch.adjoint(U) @ US)
+
+
+def svd_metric(V1, V2, S1, S2):
+    return 0.5 * (captured_variance(V1, V2, S2) + captured_variance(V2, V1, S1))
+
+
+def sort_matrix_using_gt(dist_matrix: torch.Tensor, labels: np.ndarray):
+    sort_idx = torch.argsort(dist_matrix[:, -1])
+    dist_matrix = dist_matrix[:, sort_idx][sort_idx]
+    labels = labels[sort_idx.numpy()]
+
+    dist_matrix = torch.flip(dist_matrix, dims=(0,))
+    dist_matrix = torch.flip(dist_matrix, dims=(1,))
+    labels = np.flip(labels)
+
+    return dist_matrix, labels
+
+
+def compute_distance_matrix(submissions_data, gt_data=None):
+    n_subs = len(list(submissions_data.keys()))
+    labels = list(submissions_data.keys())
+    dtype = submissions_data[labels[0]]["eigenvectors"].dtype
+
+    if gt_data is not None:
+        dist_matrix = torch.ones((n_subs + 1, n_subs + 1), dtype=dtype)
+    else:
+        dist_matrix = torch.ones((n_subs, n_subs), dtype=dtype)
+
+    for i, label1 in enumerate(labels):
+        for j, label2 in enumerate(labels[i:]):
+            dist_matrix[i, j + i] = svd_metric(
+                submissions_data[label1]["eigenvectors"],
+                submissions_data[label2]["eigenvectors"],
+                submissions_data[label1]["singular_values"],
+                submissions_data[label2]["singular_values"],
+            )
+            dist_matrix[j + i, i] = dist_matrix[i, j + i]
+
+    if gt_data is not None:
+        for i, label in enumerate(labels):
+            dist_matrix[i, n_subs] = svd_metric(
+                submissions_data[label]["eigenvectors"],
+                gt_data["eigenvectors"],
+                submissions_data[label]["singular_values"],
+                gt_data["singular_values"],
+            )
+            dist_matrix[n_subs, i] = dist_matrix[i, n_subs]
+
+        labels.append("Ground Truth")
+        labels = np.array(labels)
+
+        dist_matrix, labels = sort_matrix_using_gt(dist_matrix, labels)
+
+    labels = np.array(labels)
+    results = {"dist_matrix": dist_matrix, "labels": labels}
+    return results
+
+
+### Compute common embedding ###
+
+
+def compute_common_embedding(submissions_data, gt_data=None):
+    labels = list(submissions_data.keys())
+    n_subs = len(labels)
+    shape_per_sub = submissions_data[labels[0]]["eigenvectors"].T.shape
+    dtype = submissions_data[labels[0]]["eigenvectors"].dtype
+    eigenvectors = torch.zeros(
+        (n_subs * shape_per_sub[0], shape_per_sub[1]), dtype=dtype
+    )
+
+    for i, label in enumerate(labels):
+        eigenvectors[i * shape_per_sub[0] : (i + 1) * shape_per_sub[0], :] = (
+            submissions_data[label]["eigenvectors"].T
+        )
+
+    U, S, V = torch.linalg.svd(eigenvectors, full_matrices=False)
+
+    Z_common = (U @ torch.diag(S)).reshape(n_subs, shape_per_sub[0], -1)
+    embeddings = {}
+
+    for i, label in enumerate(labels):
+        Z_i = submissions_data[label]["u_matrices"] @ torch.diag(
+            submissions_data[label]["singular_values"]
+        )
+        Z_i_common = torch.einsum("ij, jk -> ik", Z_i, Z_common[i])
+        embeddings[labels[i]] = Z_i_common
+
+    results = {
+        "common_embedding": embeddings,
+    }
+
+    if gt_data is not None:
+        gt_proj = gt_data["u_matrices"] @ torch.diag(gt_data["singular_values"])
+        gt_proj = gt_proj @ (gt_data["eigenvectors"].T @ V.T)
+
+        results["gt_embedding"] = gt_proj
+
+    return results
+
+
+### Project to GT embedding ###
+def project_to_gt_embedding(submissions_data, gt_data):
+    embedding_in_gt = {}
+
+    for label, submission in submissions_data.items():
+        projection = (
+            submission["u_matrices"]
+            @ torch.diag(submission["singular_values"])
+            @ (submission["eigenvectors"].T @ gt_data["eigenvectors"])
+        )
+        embedding_in_gt[label] = projection
+
+    gt_coords = gt_data["u_matrices"] @ torch.diag(gt_data["singular_values"])
+
+    results = {"submission_embedding": embedding_in_gt, "gt_embedding": gt_coords}
+
+    return results
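captured_variance above measures how much of one decomposition's variance falls inside another decomposition's span, and svd_metric symmetrizes it; identical subspaces score exactly 1, so the score is a similarity even though the matrix that stores it is labeled a distance matrix. A self-contained numeric check, using QR to build orthonormal bases:

    import torch

    def captured_variance(V, U, S):
        US = U @ torch.diag(S)
        return torch.norm(torch.adjoint(V) @ US) / torch.norm(torch.adjoint(U) @ US)

    torch.manual_seed(0)
    A = torch.linalg.qr(torch.randn(1000, 5)).Q  # 5-dim subspace basis, (features, rank)
    B = torch.linalg.qr(torch.randn(1000, 5)).Q  # an unrelated 5-dim subspace
    S = torch.ones(5)

    print(captured_variance(A, A, S))  # 1.0: a subspace captures all of its own variance
    print(captured_variance(A, B, S))  # near 0: random subspaces barely overlap in 1000-dim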
diff --git a/src/cryo_challenge/data/_io/svd_io_utils.py b/src/cryo_challenge/data/_io/svd_io_utils.py
index 24b44cc..bc4a12f 100644
--- a/src/cryo_challenge/data/_io/svd_io_utils.py
+++ b/src/cryo_challenge/data/_io/svd_io_utils.py
@@ -1,153 +1,200 @@
 import torch
+import numpy as np
 from typing import Tuple
+import os
+from natsort import natsorted
+import mrcfile
 
-from ..._preprocessing.fourier_utils import downsample_volume
+from ..._preprocessing.fourier_utils import downsample_volume, downsample_submission
+from ..._preprocessing.bfactor_normalize import bfactor_normalize_volumes
 
 
-def _remove_mean_volumes_sub(volumes, metadata):
-    box_size = volumes.shape[-1]
-    n_subs = len(list(metadata.keys()))
-    mean_volumes = torch.zeros((n_subs, box_size, box_size, box_size))
+def load_submissions_svd(
+    config: dict,
+) -> dict:
+    """
+    Load the volumes and populations from the submission_*.pt files found in path_to_submissions. Volumes are optionally masked, B-factor normalized, and downsampled, then scaled to unit norm; the mean volume is removed before computing each submission's SVD.
 
-    for i, key in enumerate(metadata.keys()):
-        indices = metadata[key]["indices"]
+    Parameters
+    ----------
+    config: dict
+        Dictionary containing the configuration parameters.
+
+    Returns
+    -------
+    submissions_data: dict
+        Dictionary containing the populations, left singular vectors, singular values, and right singular vectors of each submission.
+    """  # noqa: E501
 
-        mean_volumes[i] = torch.mean(volumes[indices[0] : indices[1]], dim=0)
-        volumes[indices[0] : indices[1]] = (
-            volumes[indices[0] : indices[1]] - mean_volumes[i][None, ...]
-        )
+    path_to_submissions = config["path_to_submissions"]
+    excluded_submissions = config["excluded_submissions"]
 
-    return volumes, mean_volumes
+    submissions_data = {}
+    submission_files = []
+    for file in os.listdir(path_to_submissions):
+        if file.endswith(".pt") and "submission" in file:
+            if file in excluded_submissions:
+                continue
+            submission_files.append(file)
+    submission_files = natsorted(submission_files)
 
-def remove_mean_volumes(volumes, metadata=None):
-    volumes = volumes.clone()
-    if metadata is None:
-        mean_volumes = torch.mean(volumes, dim=0)
-        volumes = volumes - mean_volumes[None, ...]
+    vols = torch.load(os.path.join(path_to_submissions, submission_files[0]))["volumes"]
+    box_size = vols.shape[-1]
 
-    else:
-        volumes, mean_volumes = _remove_mean_volumes_sub(volumes, metadata)
+    if config["normalize_params"]["mask_path"] is not None:
+        mask = torch.tensor(
+            mrcfile.open(config["normalize_params"]["mask_path"], mode="r").data.copy()
+        )
+        try:
+            mask = mask.reshape(1, box_size, box_size, box_size)
+        except RuntimeError:
+            raise ValueError(
+                "Mask shape does not match the box size of the volumes in the submissions."
+            )
+
+    for file in submission_files:
+        sub_path = os.path.join(path_to_submissions, file)
+        submission = torch.load(sub_path)
+
+        label = submission["id"]
+        populations = submission["populations"]
+
+        volumes = submission["volumes"]
+        if config["normalize_params"]["mask_path"] is not None:
+            volumes = volumes * mask
+
+        if config["normalize_params"]["bfactor"] is not None:
+            voxel_size = config["voxel_size"]
+            volumes = bfactor_normalize_volumes(
+                volumes,
+                config["normalize_params"]["bfactor"],
+                voxel_size,
+                in_place=True,
+            )
+
+        if config["normalize_params"]["box_size_ds"] is not None:
+            volumes = downsample_submission(
+                volumes, box_size_ds=config["normalize_params"]["box_size_ds"]
+            )
+            box_size = config["normalize_params"]["box_size_ds"]
+        else:
+            box_size = volumes.shape[-1]
+
+        volumes = volumes.reshape(-1, box_size * box_size * box_size)
+
+        if config["dtype"] == "float32":
+            volumes = volumes.float()
+        elif config["dtype"] == "float64":
+            volumes = volumes.double()
+
+        volumes /= torch.norm(volumes, dim=1, keepdim=True)
+
+        if config["svd_max_rank"] is None:
+            u_matrices, singular_values, eigenvectors = torch.linalg.svd(
+                volumes - volumes.mean(0, keepdim=True), full_matrices=False
+            )
+            eigenvectors = eigenvectors.T
+
+        else:
+            u_matrices, singular_values, eigenvectors = torch.svd_lowrank(
+                volumes - volumes.mean(0, keepdim=True), q=config["svd_max_rank"]
+            )
+
+        submissions_data[label] = {
+            "populations": torch.tensor(populations / populations.sum()),
+            "u_matrices": u_matrices.clone(),
+            "singular_values": singular_values.clone(),
+            "eigenvectors": eigenvectors.clone(),
+        }
 
-    return volumes, mean_volumes
+    return submissions_data
 
 
-def load_volumes(
-    box_size_ds: float,
-    submission_list: list,
-    path_to_submissions: str,
-    dtype=torch.float32,
-) -> Tuple[torch.tensor, dict]:
+def load_gt_svd(config: dict) -> dict:
     """
-    Load the volumes and populations from the submissions specified in submission_list. Volumes are first downsampled, then normalized so that they sum to 1, and finally the mean volume is removed from each submission.
+    Load the ground truth volumes, downsample them, normalize them, and remove the mean volume. Then compute the SVD of the volumes.
 
     Parameters
     ----------
-    box_size_ds: float
-        Size of the downsampled box.
-    submission_list: list
-        List of submission indices to load.
-    path_to_submissions: str
-        Path to the directory containing the submissions.
-    dtype: torch.dtype
-        Data type of the volumes.
+    config: dict
+        Dictionary containing the configuration parameters.
 
     Returns
    -------
-    volumes: torch.tensor
-        Tensor of shape (n_volumes, n_x, n_y, n_z) containing the volumes.
-    metadata: dict
-        Dictionary containing the metadata for each submission.
-        The keys are the id (ice cream name) of each submission.
-        The values are dictionaries containing the number of volumes, the populations, and the indices of the volumes in the volumes tensor.
-
-    Examples
-    --------
-    >>> box_size_ds = 64
-    >>> submission_list = [0, 1, 2, 3, 4]  # submission 5 is ignored
-    >>> path_to_submissions = "/path/to/submissions"  # under this folder submissions should be named submission_i.pt
-    >>> volumes, populations = load_volumes(box_size_ds, submission_list, path_to_submissions)
-    """  # noqa: E501
+    gt_data: dict
+        Dictionary containing the left singular vectors, singular values, and right singular vectors of the ground truth volumes.
+    """
 
-    metadata = {}
-    volumes = torch.empty((0, box_size_ds, box_size_ds, box_size_ds), dtype=dtype)
+    vols_gt = np.load(config["gt_params"]["gt_vols_file"], mmap_mode="r")
 
-    counter = 0
+    if len(vols_gt.shape) == 2:
+        box_size_gt = int(round((float(vols_gt.shape[-1]) ** (1.0 / 3.0))))
 
-    for idx in submission_list:
-        submission = torch.load(f"{path_to_submissions}/submission_{idx}.pt")
-        vols = submission["volumes"]
-        pops = submission["populations"]
+    elif len(vols_gt.shape) == 4:
+        box_size_gt = vols_gt.shape[-1]
 
-        vols_tmp = torch.empty(
-            (vols.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
+    if config["normalize_params"]["box_size_ds"] is not None:
+        box_size = config["normalize_params"]["box_size_ds"]
+    else:
+        box_size = box_size_gt
+
+    if config["normalize_params"]["mask_path"] is not None:
+        mask = torch.tensor(
+            mrcfile.open(config["normalize_params"]["mask_path"], mode="r").data.copy()
        )
-        counter_start = counter
-        for j in range(vols.shape[0]):
-            vol_ds = downsample_volume(vols[j], box_size_ds)
-            vols_tmp[j] = vol_ds / vol_ds.sum()
-            counter += 1
-
-        metadata[submission["id"]] = {
-            "n_vols": vols.shape[0],
-            "populations": pops / pops.sum(),
-            "indices": (counter_start, counter),
-        }
-        volumes = torch.cat((volumes, vols_tmp), dim=0)
+        try:
+            mask = mask.reshape(box_size_gt, box_size_gt, box_size_gt)
+        except RuntimeError:
+            raise ValueError(
+                "Mask shape does not match the box size of the ground truth volumes."
+            )
 
-    return volumes, metadata
+    skip_vols = config["gt_params"]["skip_vols"]
+    n_vols = vols_gt.shape[0] // skip_vols
 
+    if config["dtype"] == "float32":
+        dtype = torch.float32
 
-def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):
-    """
-    Load the reference volumes, downsample them, normalize them, and remove the mean volume.
+    else:
+        dtype = torch.float64
 
-    Parameters
-    ----------
-    box_size_ds: int
-        Size of the downsampled box.
-    path_to_volumes: str
-        Path to the file containing the reference volumes. Must be in PyTorch format.
-    dtype: torch.dtype
-        Data type of the volumes.
+    volumes_gt = torch.zeros((n_vols, box_size * box_size * box_size), dtype=dtype)
 
-    Returns
-    -------
-    volumes_ds: torch.tensor
-        Tensor of shape (n_volumes, n_x, n_y, n_z) containing the downsampled, normalized, and mean-removed reference volumes.
-
-    Examples
-    --------
-    >>> box_size_ds = 64
-    >>> path_to_volumes = "/path/to/volumes.pt"
-    >>> volumes_ds = load_ref_vols(box_size_ds, path_to_volumes)
-    """  # noqa: E501
-    try:
-        volumes = torch.load(path_to_volumes)
-    except (FileNotFoundError, EOFError):
-        raise ValueError("Volumes not found or not in PyTorch format.")
-
-    # Reshape volumes to correct size
-    if volumes.dim() == 2:
-        box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0))))
-        volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size))
-    elif volumes.dim() == 4:
-        pass
-    else:
-        raise ValueError(
-            f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
-            f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
-            f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)."
+    for i in range(n_vols):
+        vol_tmp = torch.from_numpy(
+            vols_gt[i * skip_vols].copy().reshape(box_size_gt, box_size_gt, box_size_gt)
        )
 
-    volumes_ds = torch.empty(
-        (volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
+        if dtype == torch.float32:
+            vol_tmp = vol_tmp.float()
+        else:
+            vol_tmp = vol_tmp.double()
+
+        if config["normalize_params"]["mask_path"] is not None:
+            vol_tmp *= mask
+
+        if config["normalize_params"]["bfactor"] is not None:
+            bfactor = config["normalize_params"]["bfactor"]
+            voxel_size = config["voxel_size"]
+            vol_tmp = bfactor_normalize_volumes(
+                vol_tmp, bfactor, voxel_size, in_place=True
+            )
+
+        if config["normalize_params"]["box_size_ds"] is not None:
+            vol_tmp = downsample_volume(vol_tmp, box_size_ds=box_size)
+
+        volumes_gt[i] = vol_tmp.reshape(-1)
+
+    volumes_gt /= torch.norm(volumes_gt, dim=1, keepdim=True)
+
+    U, S, V = torch.svd_lowrank(
+        volumes_gt - volumes_gt.mean(0, keepdim=True), q=config["svd_max_rank"]
    )
-    for i, vol in enumerate(volumes):
-        volumes_ds[i] = downsample_volume(vol, box_size_ds)
-        volumes_ds[i] = volumes_ds[i] / volumes_ds[i].sum()
-    volumes_ds = volumes_ds
 
-    return volumes_ds
+    gt_data = {
+        "u_matrices": U.clone(),
+        "singular_values": S.clone(),
+        "eigenvectors": V.clone(),
+    }
+
+    return gt_data
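load_submissions_svd only assumes three keys in each preprocessed submission_*.pt file. A sketch of the expected structure with dummy values (the label "Mango" is made up; real files come out of the preprocessing step):

    import torch

    submission = {
        "id": "Mango",                          # anonymous label, becomes the dict key
        "volumes": torch.rand(10, 32, 32, 32),  # (n_volumes, box, box, box)
        "populations": torch.ones(10),          # per-volume weights, normalized by the loader
    }
    torch.save(submission, "submission_0.pt")   # filename must contain "submission" and end in .pt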
- """ # noqa: E501 - keys_and_types = { - "ref_vol_key": str, - "ref_vol_index": Number, - } - validate_generic_config(config_normalization, keys_and_types) - return - - -def validate_config_svd(config: dict) -> None: - """ - Validate the config dictionary for the SVD pipeline. - """ # noqa: E501 - keys_and_types = { - "path_to_volumes": str, - "box_size_ds": Number, - "submission_list": list, - "experiment_mode": str, - "power_spectrum_normalization": dict, - "dtype": str, - "output_options": dict, - } - - validate_generic_config(config, keys_and_types) - validate_config_svd_output(config["output_options"]) - validate_power_spectrum_normalization(config["power_spectrum_normalization"]) - - if config["experiment_mode"] == "all_vs_ref": - if "path_to_reference" not in config.keys(): +class SVDNormalizeParams(BaseModel): + mask_path: Optional[str] = None + bfactor: float = None + box_size_ds: Optional[int] = None + + @validator("mask_path") + def check_mask_path_exists(cls, value): + if value is not None: + if not os.path.exists(value): + raise ValueError(f"Mask file {value} does not exist.") + return value + + @validator("bfactor") + def check_bfactor(cls, value): + if value is not None: + if value < 0: + raise ValueError("B-factor must be non-negative.") + return value + + @validator("box_size_ds") + def check_box_size_ds(cls, value): + if value is not None: + if value < 0: + raise ValueError("Downsampled box size must be non-negative.") + return value + + +class SVDGtParams(BaseModel): + gt_vols_file: str + skip_vols: int = 1 + + @validator("gt_vols_file") + def check_mask_path_exists(cls, value): + if not os.path.exists(value): + raise ValueError(f"Could not find file {value}.") + + assert value.endswith(".npy"), "Ground truth volumes file must be a .npy file." 
+ + vols_gt = np.load(value, mmap_mode="r") + + if len(vols_gt.shape) not in [2, 4]: raise ValueError( - "Reference path is required for experiment mode 'all_vs_ref'" + "Invalid number of dimensions for the ground truth volumes" + ) + return value + + @validator("skip_vols") + def check_skip_vols(cls, value): + if value is not None: + if value < 0: + raise ValueError("Number of volumes to skip must be non-negative.") + return value + + +class SVDOutputParams(BaseModel): + output_file: str + save_svd_data: bool = False + generate_plots: bool = False + + +class SVDConfig(BaseModel): + # Main configuration fields + path_to_submissions: str + voxel_size: float + excluded_submissions: List[str] = [] + dtype: str = "float32" + svd_max_rank: Optional[int] = None + + # Subdictionaries + normalize_params: SVDNormalizeParams = SVDNormalizeParams() + gt_params: Optional[SVDGtParams] = None + output_params: SVDOutputParams + + @root_validator + def check_path_to_submissions(cls, values): + path_to_submissions = values.get("path_to_submissions") + excluded_submissions = values.get("excluded_submissions") + + if not os.path.exists(path_to_submissions): + raise ValueError(f"Could not find path {path_to_submissions}.") + + submission_files = [] + for file in os.listdir(path_to_submissions): + if file.endswith(".pt") and "submission" in file: + submission_files.append(file) + if len(submission_files) == 0: + raise ValueError(f"No submission files found in {path_to_submissions}.") + + submission_files = [] + for file in os.listdir(path_to_submissions): + if file.endswith(".pt") and "submission" in file: + if file in excluded_submissions: + continue + submission_files.append(file) + + if len(submission_files) == 0: + raise ValueError( + f"No submission files found after excluding {excluded_submissions}." 
) - else: - assert isinstance(config["path_to_reference"], str) - os.path.exists(config["path_to_reference"]) - assert ( - "pt" in config["path_to_reference"] - ), "Reference path point to a .pt file" - - os.path.exists(config["path_to_volumes"]) - for submission in config["submission_list"]: - sub_path = os.path.join( - config["path_to_volumes"] + f"submission_{submission}.pt" - ) - os.path.exists(sub_path) - - assert config["dtype"] in [ - "float32", - "float64", - ], "dtype must be either 'float32' or 'float64'" - assert config["box_size_ds"] > 0, "box_size_ds must be greater than 0" - assert config["submission_list"] != [], "submission_list must not be empty" - - return + return values + + @validator("dtype") + def check_dtype(cls, value): + if value not in ["float32", "float64"]: + raise ValueError(f"Invalid dtype {value}.") + return value + + @validator("svd_max_rank") + def check_svd_max_rank(cls, value): + if value < 1 and value is not None: + raise ValueError("Max rank must be at least 1.") + return value + + @validator("voxel_size") + def check_voxel_size(cls, value): + if value <= 0: + raise ValueError("Voxel size must be positive.") + return value diff --git a/tests/config_files/test_config_svd.yaml b/tests/config_files/test_config_svd.yaml index 5a040b5..b24ad41 100644 --- a/tests/config_files/test_config_svd.yaml +++ b/tests/config_files/test_config_svd.yaml @@ -1,19 +1,21 @@ -path_to_volumes: tests/data/dataset_2_submissions/ -box_size_ds: 16 -submission_list: [1000] -experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref" -# optional unless experiment_mode is "all_vs_ref" +path_to_submissions: tests/data/dataset_2_submissions/ +#excluded_submissions: # you can exclude some submissions by filename +# - "submission_0.pt" +# - "submission_1.pt" -power_spectrum_normalization: - ref_vol_key: "Coffee" # which submission should be used - ref_vol_index: 0 # which volume of that submission should be used +dtype: float32 +svd_max_rank: 5 +voxel_size: 1.0 # voxel size of the input maps +normalize_params: # optional, if not given there will be no normalization + mask_path: tests/data/Ground_truth/test_mask_dilated_wide.mrc + bfactor: 170 + box_size_ds: 16 -path_to_reference: tests/data/Ground_truth/test_maps_gt_flat_10.pt -dtype: "float32" # options are "float32", "float64" -output_options: - # path will be created if it does not exist +gt_params: # optional, if provided there will be extra results + gt_vols_file: tests/data/Ground_truth/test_maps_gt_flat_10.npy + skip_vols: 1 + +output_params: output_file: tests/results/svd/svd_result.pt - # whether or not to save the processed volumes (downsampled, normalized, etc.) - save_volumes: True - # whether or not to save the SVD matrices (U, S, V) - save_svd_matrices: True + save_svd_data: True + generate_plots: True
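Taken together, the test config exercises every new code path: masking, B-factor normalization, downsampling, and ground-truth projection. A sketch of driving it end to end from Python, mirroring main() in run_svd.py above (it assumes the repository's test data is on disk and the package is installed):

    import yaml
    from cryo_challenge.data._validation.config_validators import SVDConfig
    from cryo_challenge._svd.svd_pipeline import run_svd_noref, run_svd_with_ref

    with open("tests/config_files/test_config_svd.yaml", "r") as file:
        config = SVDConfig(**yaml.safe_load(file)).dict()

    if config["gt_params"] is None:
        run_svd_noref(config)
    else:
        run_svd_with_ref(config)  # this config sets gt_params, so GT results are produced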