diff --git a/README.md b/README.md index adc58686..9292a9c4 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,9 @@ + + +

🌟 Welcome to EthoML/VAME (Variational Animal Motion Encoding), an open-source machine learning tool for behavioral segmentation and analyses. diff --git a/examples/demo.ipynb b/examples/demo.ipynb index 17c59e83..5eb06e73 100644 --- a/examples/demo.ipynb +++ b/examples/demo.ipynb @@ -142,7 +142,7 @@ "outputs": [], "source": [ "# # OPTIONAL: Create behavioural hierarchies via community detection\n", - "vame.community(config, show_umap=False, cut_tree=2, cohort=True)" + "vame.community(config, cut_tree=2, cohort=False)" ] }, { diff --git a/examples/demo.py b/examples/demo.py index 379ecd03..3d15532d 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -56,7 +56,7 @@ # vame.motif_videos(config, videoType='.mp4') # # OPTIONAL: Create behavioural hierarchies via community detection -# vame.community(config, show_umap=False, cut_tree=2) +# vame.community(config, cut_tree=2) # # OPTIONAL: Create community videos to get insights about behavior on a hierarchical scale # vame.community_videos(config) diff --git a/pyproject.toml b/pyproject.toml index 4d11d4af..19d15ed8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vame-py" -version = '0.2.0' +version = '0.3.0' dynamic = ["dependencies"] description = "Variational Animal Motion Embedding." authors = [ diff --git a/src/vame/__init__.py b/src/vame/__init__.py index 86fdef2b..95d9a97d 100644 --- a/src/vame/__init__.py +++ b/src/vame/__init__.py @@ -24,5 +24,6 @@ from vame.analysis import gif from vame.util.csv_to_npy import csv_to_numpy from vame.util.align_egocentrical import egocentric_alignment +from vame.util import model_util from vame.util import auxiliary diff --git a/src/vame/analysis/__init__.py b/src/vame/analysis/__init__.py index 87e3ccdf..c8450399 100644 --- a/src/vame/analysis/__init__.py +++ b/src/vame/analysis/__init__.py @@ -14,7 +14,7 @@ from vame.analysis.pose_segmentation import pose_segmentation from vame.analysis.videowriter import motif_videos, community_videos from vame.analysis.community_analysis import community -from vame.analysis.umap_visualization import visualization +from vame.analysis.umap import visualization from vame.analysis.generative_functions import generative_model from vame.analysis.gif_creator import gif diff --git a/src/vame/analysis/community_analysis.py b/src/vame/analysis/community_analysis.py index 95f6cfc1..fb2cd479 100644 --- a/src/vame/analysis/community_analysis.py +++ b/src/vame/analysis/community_analysis.py @@ -12,7 +12,6 @@ """ import os -import umap import scipy import pickle import numpy as np @@ -20,6 +19,7 @@ import matplotlib.pyplot as plt from vame.util.auxiliary import read_config from vame.analysis.tree_hierarchy import graph_to_tree, draw_tree, traverse_tree_cutline +from vame.util.data_manipulation import consecutive from typing import List, Tuple from vame.schemas.states import save_state, CommunityFunctionSchema from vame.logging.logger import VameLogger @@ -87,21 +87,6 @@ def get_transition_matrix(adjacency_matrix: np.ndarray, threshold: float = 0.0) return transition_matrix -def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray]: - """Identifies location of missing motif finding consecutive elements in an array and returns array(s) at the split. - - Args: - data (np.ndarray): Input array. - stepsize (int, optional): Step size. Defaults to 1. - - Returns: - List[np.ndarray]: List of arrays containing consecutive elements. 
- """ - data = data[:] - return np.split(data, np.where(np.diff(data) != stepsize)[0]+1) - -# New find_zero_labels 8/8/2022 KL - def find_zero_labels(motif_usage: Tuple[np.ndarray, np.ndarray], n_cluster: int) -> np.ndarray: """Find zero labels in motif usage and fill them. @@ -311,7 +296,6 @@ def create_community_bag( return communities_all, trees def create_cohort_community_bag( - files: List[str], labels: List[np.ndarray], trans_mat_full: np.ndarray, cut_tree: int, @@ -321,7 +305,6 @@ def create_cohort_community_bag( (markov chain to tree -> community detection) Args: - files (List[str]): List of files paths (deprecated). labels (List[np.ndarray]): List of label arrays. trans_mat_full (np.ndarray): Full transition matrix. cut_tree (int): Cut line for tree. @@ -431,76 +414,11 @@ def get_cohort_community_labels( return community_labels_all -def umap_embedding(cfg: dict, file: str, model_name: str, n_cluster: int, parametrization: str) -> np.ndarray: - """Perform UMAP embedding for given file and parameters. - - Args: - cfg (dict): Configuration parameters. - file (str): File path. - model_name (str): Model name. - n_cluster (int): Number of clusters. - parametrization (str): parametrization. - - Returns: - np.ndarray: UMAP embedding. - """ - reducer = umap.UMAP(n_components=2, min_dist=cfg['min_dist'], n_neighbors=cfg['n_neighbors'], - random_state=cfg['random_state']) - - logger.info("UMAP calculation for file %s" %file) - - folder = os.path.join(cfg['project_path'],"results",file,model_name, parametrization +'-'+str(n_cluster),"") - latent_vector = np.load(os.path.join(folder,'latent_vector_'+file+'.npy')) - - num_points = cfg['num_points'] - if num_points > latent_vector.shape[0]: - num_points = latent_vector.shape[0] - logger.info("Embedding %d data points.." %num_points) - - embed = reducer.fit_transform(latent_vector[:num_points,:]) - - return embed - - -def umap_vis(cfg: dict, file: str, embed: np.ndarray, community_labels_all: np.ndarray, save_path: str | None) -> None: - """Create plotly visualizaton of UMAP embedding. - - Args: - cfg (dict): Configuration parameters. - file (str): File path. - embed (np.ndarray): UMAP embedding. - community_labels_all (np.ndarray): Community labels. - save_path: Path to save the plot. If None it will not save the plot. - - Returns: - None - """ - num_points = cfg['num_points'] - community_labels_all = np.asarray(community_labels_all) - if num_points > community_labels_all.shape[0]: - num_points = community_labels_all.shape[0] - logger.info("Embedding %d data points.." %num_points) - - num = np.unique(community_labels_all) - - fig = plt.figure(1) - plt.scatter(embed[:,0], embed[:,1], c=community_labels_all[:num_points], cmap='Spectral', s=2, alpha=1) - plt.colorbar(boundaries=np.arange(np.max(num)+2)-0.5).set_ticks(np.arange(np.max(num)+1)) - plt.gca().set_aspect('equal', 'datalim') - plt.grid(False) - - if save_path is not None: - plt.savefig(save_path) - return - plt.show() - @save_state(model=CommunityFunctionSchema) def community( config: str, cohort: bool = True, - show_umap: bool = False, cut_tree: int | None = None, - save_umap_figure: bool = True, save_logs: bool = False ) -> None: """Perform community analysis. @@ -508,7 +426,6 @@ def community( Args: config (str): Path to the configuration file. cohort (bool, optional): Flag indicating cohort analysis. Defaults to True. - show_umap (bool, optional): Flag indicating weather to show UMAP visualization. Defaults to False. cut_tree (int, optional): Cut line for tree. Defaults to None. 
Returns: @@ -551,7 +468,7 @@ def community( augmented_label, zero_motifs = augment_motif_timeseries(labels, n_cluster) _, trans_mat_full,_ = get_adjacency_matrix(augmented_label, n_cluster=n_cluster) _, usage_full = np.unique(augmented_label, return_counts=True) - communities_all, trees = create_cohort_community_bag(files, labels, trans_mat_full, cut_tree, n_cluster) + communities_all, trees = create_cohort_community_bag(labels, trans_mat_full, cut_tree, n_cluster) community_labels_all = get_cohort_community_labels(files, labels, communities_all) # community_bag = traverse_tree_cutline(trees, cutline=cut_tree) @@ -567,11 +484,6 @@ def community( with open(os.path.join(cfg['project_path'],"hierarchy"+".pkl"), "wb") as fp: #Pickling pickle.dump(communities_all, fp) - if show_umap: - embed = umap_embedding(cfg, files, model_name, n_cluster, parametrization) - # TODO fix umap vis for cohort and add save path - umap_vis(cfg, files, embed, community_labels_all) - # Work in Progress elif not cohort: labels = get_labels(cfg, files, model_name, n_cluster, parametrization) @@ -590,12 +502,6 @@ def community( with open(os.path.join(path_to_file,"community","hierarchy"+file+".pkl"), "wb") as fp: #Pickling pickle.dump(communities_all[idx], fp) - if show_umap: - embed = umap_embedding(cfg, file, model_name, n_cluster, parametrization) - umap_save_path = None - if save_umap_figure: - umap_save_path = os.path.join(path_to_file, "community", file + "_umap.png") - umap_vis(cfg, files, embed, community_labels_all[idx], save_path=umap_save_path) except Exception as e: logger.exception(f"Error in community_analysis: {e}") raise e diff --git a/src/vame/analysis/generative_functions.py b/src/vame/analysis/generative_functions.py index fb59fc1b..072ed583 100644 --- a/src/vame/analysis/generative_functions.py +++ b/src/vame/analysis/generative_functions.py @@ -16,8 +16,9 @@ from sklearn.mixture import GaussianMixture from vame.schemas.states import GenerativeModelFunctionSchema, save_state from vame.util.auxiliary import read_config -from vame.model.rnn_model import RNN_VAE from vame.logging.logger import VameLogger +from vame.util.model_util import load_model + logger_config = VameLogger(__name__) logger = logger_config.logger @@ -189,48 +190,6 @@ def visualize_cluster_center(cfg: dict, model: torch.nn.Module, cluster_center: return fig -def load_model(cfg: dict, model_name: str) -> torch.nn.Module: - """Load PyTorch model. - - Args: - cfg (dict): Configuration dictionary. - model_name (str): Name of the model. - - Returns: - torch.nn.Module: Loaded PyTorch model. - """ - ZDIMS = cfg['zdims'] - FUTURE_DECODER = cfg['prediction_decoder'] - TEMPORAL_WINDOW = cfg['time_window']*2 - FUTURE_STEPS = cfg['prediction_steps'] - - NUM_FEATURES = cfg['num_features'] - NUM_FEATURES = NUM_FEATURES - 2 - - hidden_size_layer_1 = cfg['hidden_size_layer_1'] - hidden_size_layer_2 = cfg['hidden_size_layer_2'] - hidden_size_rec = cfg['hidden_size_rec'] - hidden_size_pred = cfg['hidden_size_pred'] - dropout_encoder = cfg['dropout_encoder'] - dropout_rec = cfg['dropout_rec'] - dropout_pred = cfg['dropout_pred'] - softplus = cfg['softplus'] - - logger.info('Loading model... 
') - - model = RNN_VAE(TEMPORAL_WINDOW,ZDIMS,NUM_FEATURES,FUTURE_DECODER,FUTURE_STEPS, hidden_size_layer_1, - hidden_size_layer_2, hidden_size_rec, hidden_size_pred, dropout_encoder, - dropout_rec, dropout_pred, softplus) - if torch.cuda.is_available(): - model = model.cuda() - else: - model = model.cpu() - - model.load_state_dict(torch.load(os.path.join(cfg['project_path'],'model','best_model',model_name+'_'+cfg['Project']+'.pkl'))) - model.eval() - - return model - @save_state(model=GenerativeModelFunctionSchema) def generative_model(config: str, mode: str = "sampling", save_logs: bool = False) -> plt.Figure: """Generative model. @@ -276,7 +235,7 @@ def generative_model(config: str, mode: str = "sampling", save_logs: bool = Fals files.append(all_flag) - model = load_model(cfg, model_name) + model = load_model(cfg, model_name, fixed=False) for file in files: path_to_file=os.path.join(cfg['project_path'],"results",file,model_name, parametrization + '-' +str(n_cluster),"") diff --git a/src/vame/analysis/pose_segmentation.py b/src/vame/analysis/pose_segmentation.py index 51ab5b4a..c0c760d5 100644 --- a/src/vame/analysis/pose_segmentation.py +++ b/src/vame/analysis/pose_segmentation.py @@ -16,69 +16,20 @@ import numpy as np from pathlib import Path from typing import List, Tuple - +from vame.util.data_manipulation import consecutive from hmmlearn import hmm from sklearn.cluster import KMeans from vame.schemas.states import save_state, PoseSegmentationFunctionSchema from vame.logging.logger import VameLogger, TqdmToLogger from vame.util.auxiliary import read_config from vame.model.rnn_model import RNN_VAE +from vame.util.model_util import load_model logger_config = VameLogger(__name__) logger = logger_config.logger -def load_model(cfg: dict, model_name: str, fixed: bool) -> RNN_VAE: - """Load the VAME model. - - Args: - cfg (dict): Configuration dictionary. - model_name (str): Name of the model. - fixed (bool): Fixed or variable length sequences. - - Returns: - RNN_VAE: Loaded VAME model. - """ - use_gpu = torch.cuda.is_available() - if use_gpu: - pass - else: - torch.device("cpu") - - # load Model - ZDIMS = cfg['zdims'] - FUTURE_DECODER = cfg['prediction_decoder'] - TEMPORAL_WINDOW = cfg['time_window']*2 - FUTURE_STEPS = cfg['prediction_steps'] - NUM_FEATURES = cfg['num_features'] - if not fixed: - NUM_FEATURES = NUM_FEATURES - 2 - hidden_size_layer_1 = cfg['hidden_size_layer_1'] - hidden_size_layer_2 = cfg['hidden_size_layer_2'] - hidden_size_rec = cfg['hidden_size_rec'] - hidden_size_pred = cfg['hidden_size_pred'] - dropout_encoder = cfg['dropout_encoder'] - dropout_rec = cfg['dropout_rec'] - dropout_pred = cfg['dropout_pred'] - softplus = cfg['softplus'] - - - if use_gpu: - model = RNN_VAE(TEMPORAL_WINDOW,ZDIMS,NUM_FEATURES,FUTURE_DECODER,FUTURE_STEPS, hidden_size_layer_1, - hidden_size_layer_2, hidden_size_rec, hidden_size_pred, dropout_encoder, - dropout_rec, dropout_pred, softplus).cuda() - else: - model = RNN_VAE(TEMPORAL_WINDOW,ZDIMS,NUM_FEATURES,FUTURE_DECODER,FUTURE_STEPS, hidden_size_layer_1, - hidden_size_layer_2, hidden_size_rec, hidden_size_pred, dropout_encoder, - dropout_rec, dropout_pred, softplus).to() - - model.load_state_dict(torch.load(os.path.join(cfg['project_path'],'model','best_model',model_name+'_'+cfg['Project']+'.pkl'))) - model.eval() - - return model - - def embedd_latent_vectors(cfg: dict, files: List[str], model: RNN_VAE, fixed: bool, tqdm_stream: TqdmToLogger | None) -> List[np.ndarray]: """Embed latent vectors for the given files using the VAME model. 
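With this refactor, `pose_segmentation` and `generative_model` share the single `load_model` implementation in `vame.util.model_util` instead of carrying near-identical private copies. A minimal usage sketch, assuming a standard VAME project config (the config path below is hypothetical):

```python
# Sketch only: calling the consolidated loader after this refactor.
# All cfg keys are the standard VAME project settings read by
# vame.util.model_util.load_model; the path is a placeholder.
from vame.util.auxiliary import read_config
from vame.util.model_util import load_model

cfg = read_config("/path/to/my_project/config.yaml")  # hypothetical path

# fixed=False subtracts 2 from num_features, matching egocentrically
# aligned input; fixed=True keeps num_features as configured.
model = load_model(cfg, cfg["model_name"], fixed=False)
```

`generative_model` passes `fixed=False` explicitly (see its hunk above), while other callers can rely on the new `fixed=True` default.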
@@ -128,20 +79,6 @@ def embedd_latent_vectors(cfg: dict, files: List[str], model: RNN_VAE, fixed: bo return latent_vector_files -def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray]: - """Find consecutive sequences in the data array. - - Args: - data (np.ndarray): Input array. - stepsize (int, optional): Step size. Defaults to 1. - - Returns: - List[np.ndarray]: List of consecutive sequences. - """ - data = data[:] - return np.split(data, np.where(np.diff(data) != stepsize)[0]+1) - - def get_motif_usage(label: np.ndarray) -> np.ndarray: """Compute motif usage from the label array. diff --git a/src/vame/analysis/tree_hierarchy.py b/src/vame/analysis/tree_hierarchy.py index 3c72d6df..3040da45 100644 --- a/src/vame/analysis/tree_hierarchy.py +++ b/src/vame/analysis/tree_hierarchy.py @@ -299,7 +299,6 @@ def graph_to_tree( return T - def draw_tree(T: nx.Graph) -> None: """ Draw a tree. @@ -317,110 +316,6 @@ def draw_tree(T: nx.Graph) -> None: figManager = plt.get_current_fig_manager() #figManager.window.showMaximized() - - -def traverse_tree(T: nx.Graph, root_node: str | None = None) -> str: - # TODO duplicated function def - """ - Traverse a tree and return the traversal sequence. - - Args: - T (nx.Graph): The tree to be traversed. - root_node (str, optional): The root node of the tree. If None, traversal starts from the root. - - Returns: - str: The traversal sequence. - """ - if root_node == None: - node=['Root'] - else: - node=[root_node] - traverse_list = [] - traverse_preorder = '{' - - def _traverse_tree(T, node, traverse_preorder): - traverse_preorder += str(node[0]) - traverse_list.append(node[0]) - children = list(T.neighbors(node[0])) - - if len(children) == 3: - # print(children) - for child in children: - if child in traverse_list: -# print(child) - children.remove(child) - - if len(children) > 1: - traverse_preorder += '{' - traverse_preorder_temp = _traverse_tree(T, [children[0]], '') - traverse_preorder += traverse_preorder_temp - - traverse_preorder += '}{' - - traverse_preorder_temp = _traverse_tree(T, [children[1]], '') - traverse_preorder += traverse_preorder_temp - traverse_preorder += '}' - - return traverse_preorder - - traverse_preorder = _traverse_tree(T, node, traverse_preorder) - traverse_preorder += '}' - - return traverse_preorder - - - -def _traverse_tree(T: nx.Graph, node: List[str], traverse_preorder: str, traverse_list: List[str]) -> str: - # TODO duplicated function def - # Aux function for traverse_tree - traverse_preorder += str(node[0]) - traverse_list.append(node[0]) - children = list(T.neighbors(node[0])) - - if len(children) == 3: -# print(children) - for child in children: - if child in traverse_list: -# print(child) - children.remove(child) - - if len(children) > 1: - traverse_preorder += '{' - traverse_preorder_temp = _traverse_tree(T, [children[0]], '',traverse_list) - traverse_preorder += traverse_preorder_temp - - traverse_preorder += '}{' - - traverse_preorder_temp = _traverse_tree(T, [children[1]], '',traverse_list) - traverse_preorder += traverse_preorder_temp - traverse_preorder += '}' - - return traverse_preorder - -def traverse_tree(T: nx.Graph, root_node: str | None = None) -> str: - # TODO duplicated function def - """ - Traverse a tree and return the traversal sequence. - - Args: - T (nx.Graph): The tree to be traversed. - root_node (str, optional): The root node of the tree. If None, traversal starts from the root. - - Returns: - str: The traversal sequence. 
- """ - if root_node == None: - node=['Root'] - else: - node=[root_node] - traverse_list = [] - traverse_preorder = '{' - traverse_preorder = _traverse_tree(T, node, traverse_preorder,traverse_list) - traverse_preorder += '}' - - return traverse_preorder - - def _traverse_tree_cutline( T: nx.Graph, node: List[str], diff --git a/src/vame/analysis/umap_visualization.py b/src/vame/analysis/umap.py similarity index 65% rename from src/vame/analysis/umap_visualization.py rename to src/vame/analysis/umap.py index f5c76a5b..ff827c75 100644 --- a/src/vame/analysis/umap_visualization.py +++ b/src/vame/analysis/umap.py @@ -24,12 +24,87 @@ logger = logger_config.logger -def umap_vis(file: str, embed: np.ndarray, num_points: int) -> None: +def umap_embedding(cfg: dict, file: str, model_name: str, n_cluster: int, parametrization: str) -> np.ndarray: + """Perform UMAP embedding for given file and parameters. + + Args: + cfg (dict): Configuration parameters. + file (str): File path. + model_name (str): Model name. + n_cluster (int): Number of clusters. + parametrization (str): parametrization. + + Returns: + np.ndarray: UMAP embedding. + """ + reducer = umap.UMAP( + n_components=2, + min_dist=cfg['min_dist'], + n_neighbors=cfg['n_neighbors'], + random_state=cfg['random_state'] + ) + + logger.info("UMAP calculation for file %s" %file) + + folder = os.path.join(cfg['project_path'],"results",file,model_name, parametrization +'-'+str(n_cluster),"") + latent_vector = np.load(os.path.join(folder,'latent_vector_'+file+'.npy')) + + num_points = cfg['num_points'] + if num_points > latent_vector.shape[0]: + num_points = latent_vector.shape[0] + logger.info("Embedding %d data points.." %num_points) + + embed = reducer.fit_transform(latent_vector[:num_points,:]) + np.save(os.path.join(folder,"community","umap_embedding_"+file+'.npy'), embed) + + return embed + + +def umap_vis_community_labels(cfg: dict, embed: np.ndarray, community_labels_all: np.ndarray, save_path: str | None) -> None: + """Create plotly visualizaton of UMAP embedding with community labels. + + Args: + cfg (dict): Configuration parameters. + embed (np.ndarray): UMAP embedding. + community_labels_all (np.ndarray): Community labels. + save_path: Path to save the plot. If None it will not save the plot. + + Returns: + None + """ + num_points = cfg['num_points'] + community_labels_all = np.asarray(community_labels_all) + if num_points > community_labels_all.shape[0]: + num_points = community_labels_all.shape[0] + logger.info("Embedding %d data points.." %num_points) + + num = np.unique(community_labels_all) + + fig = plt.figure(1) + plt.scatter( + embed[:,0], + embed[:,1], + c=community_labels_all[:num_points], + cmap='Spectral', + s=2, + alpha=1 + ) + plt.colorbar(boundaries=np.arange(np.max(num)+2)-0.5).set_ticks(np.arange(np.max(num)+1)) + plt.gca().set_aspect('equal', 'datalim') + plt.grid(False) + + if save_path is not None: + plt.savefig(save_path) + return fig + plt.show() + return fig + + +def umap_vis(embed: np.ndarray, num_points: int) -> None: """ Visualize UMAP embedding without labels. Args: - file (str): Name of the file (deprecated). embed (np.ndarray): UMAP embedding. num_points (int): Number of data points to visualize. 
@@ -46,12 +121,11 @@ def umap_vis(file: str, embed: np.ndarray, num_points: int) -> None:
     return fig
 
 
-def umap_label_vis(file: str, embed: np.ndarray, label: np.ndarray, n_cluster: int, num_points: int) -> None:
+def umap_label_vis(embed: np.ndarray, label: np.ndarray, n_cluster: int, num_points: int) -> plt.Figure:
     """
     Visualize UMAP embedding with motif labels.
 
     Args:
-        file (str): Name of the file (deprecated).
         embed (np.ndarray): UMAP embedding.
         label (np.ndarray): Motif labels.
         n_cluster (int): Number of clusters.
@@ -68,12 +142,11 @@ def umap_label_vis(file: str, embed: np.ndarray, label: np.ndarray, n_cluster: i
     return fig
 
 
-def umap_vis_comm(file: str, embed: np.ndarray, community_label: np.ndarray, num_points: int) -> None:
+def umap_vis_comm(embed: np.ndarray, community_label: np.ndarray, num_points: int) -> plt.Figure:
     """
     Visualize UMAP embedding with community labels.
 
     Args:
-        file (str): Name of the file (deprecated).
         embed (np.ndarray): UMAP embedding.
         community_label (np.ndarray): Community labels.
         num_points (int): Number of data points to visualize.
@@ -151,36 +224,27 @@ def visualization(
         if not os.path.exists(os.path.join(path_to_file,"community")):
             os.mkdir(os.path.join(path_to_file,"community"))
         logger.info("Compute embedding for file %s" %file)
-        reducer = umap.UMAP(n_components=2, min_dist=cfg['min_dist'], n_neighbors=cfg['n_neighbors'],
-                            random_state=cfg['random_state'])
-
-        latent_vector = np.load(os.path.join(path_to_file,"",'latent_vector_'+file+'.npy'))
-
+        embed = umap_embedding(cfg, file, model_name, n_cluster, param)
         num_points = cfg['num_points']
-        if num_points > latent_vector.shape[0]:
-            num_points = latent_vector.shape[0]
-        logger.info("Embedding %d data points.." %num_points)
-
-        embed = reducer.fit_transform(latent_vector[:num_points,:])
-        np.save(os.path.join(path_to_file,"community","umap_embedding_"+file+'.npy'), embed)
+        if num_points > embed.shape[0]:
+            num_points = embed.shape[0]
 
-        logger.info("Visualizing %d data points.. 
" %num_points) if label is None: - output_figure = umap_vis(file, embed, num_points) + output_figure = umap_vis(embed, num_points) fig_path = os.path.join(path_to_file,"community","umap_vis_label_none_"+file+".png") output_figure.savefig(fig_path) return output_figure if label == 'motif': motif_label = np.load(os.path.join(path_to_file,"",str(n_cluster)+'_' + param + '_label_'+file+'.npy')) - output_figure = umap_label_vis(file, embed, motif_label, n_cluster, num_points) + output_figure = umap_label_vis(embed, motif_label, n_cluster, num_points) fig_path = os.path.join(path_to_file,"community","umap_vis_motif_"+file+".png") output_figure.savefig(fig_path) return output_figure if label == "community": community_label = np.load(os.path.join(path_to_file,"","community","","community_label_"+file+".npy")) - output_figure = umap_vis_comm(file, embed, community_label, num_points) + output_figure = umap_vis_comm(embed, community_label, num_points) fig_path = os.path.join(path_to_file,"community","umap_vis_community_"+file+".png") output_figure.savefig(fig_path) return output_figure diff --git a/src/vame/initialize_project/new.py b/src/vame/initialize_project/new.py index debf8a9a..00c5df95 100644 --- a/src/vame/initialize_project/new.py +++ b/src/vame/initialize_project/new.py @@ -22,7 +22,7 @@ from pathlib import Path import shutil from datetime import datetime as dt -from vame.util import auxiliary +from vame.util.auxiliary import write_config from typing import List from vame.schemas.project import ProjectSchema from vame.schemas.states import VAMEPipelineStatesSchema @@ -141,7 +141,7 @@ def init_new_project( projconfigfile=os.path.join(str(project_path), 'config.yaml') # Write dictionary to yaml config file - auxiliary.write_config(projconfigfile, cfg_data) + write_config(projconfigfile, cfg_data) vame_pipeline_default_schema = VAMEPipelineStatesSchema() vame_pipeline_default_schema_path = Path(project_path) / 'states/states.json' diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py index b499f9e2..180deae1 100644 --- a/src/vame/model/create_training.py +++ b/src/vame/model/create_training.py @@ -19,41 +19,13 @@ from vame.logging.logger import VameLogger from vame.util.auxiliary import read_config from vame.schemas.states import CreateTrainsetFunctionSchema, save_state +from vame.util.data_manipulation import interpol_all_nans logger_config = VameLogger(__name__) logger = logger_config.logger -def nan_helper(y: np.ndarray) -> Tuple: - """ - Identifies indices of NaN values in an array and provides a function to convert them to non-NaN indices. - - Args: - y (np.ndarray): Input array containing NaN values. - - Returns: - Tuple[np.ndarray, Union[np.ndarray, None]]: A tuple containing two elements: - - An array of boolean values indicating the positions of NaN values. - - A lambda function to convert NaN indices to non-NaN indices. - """ - return np.isnan(y), lambda z: z.nonzero()[0] - -def interpol(arr: np.ndarray) -> np.ndarray: - """ - Interpolates all NaN values in the given array. - - Args: - arr (np.ndarray): Input array containing NaN values. - - Returns: - np.ndarray: Array with NaN values replaced by interpolated values. 
- """ - y = np.transpose(arr) - nans, x = nan_helper(y) - y[nans]= np.interp(x(nans), x(~nans), y[~nans]) - arr = np.transpose(y) - return arr def plot_check_parameter( cfg: dict, @@ -61,8 +33,6 @@ def plot_check_parameter( num_frames: int, X_true: List[np.ndarray], X_med: np.ndarray, - anchor_1: int | None = None, - anchor_2: int | None = None ) -> None: """ Plot the check parameter - z-scored data and the filtered data. @@ -141,7 +111,6 @@ def traindata_aligned( cfg: dict, files: List[str], testfraction: float, - num_features: int, savgol_filter: bool, check_parameter: bool ) -> None: @@ -202,7 +171,7 @@ def traindata_aligned( elif X_z[i,marker] < -cfg['iqr_factor']*iqr_val: X_z[i,marker] = np.nan - X_z = interpol(X_z) + X_z = interpol_all_nans(X_z) X_len = len(data.T) pos_temp += X_len @@ -248,7 +217,7 @@ def traindata_aligned( z_train = X_med[:,test:] if check_parameter: - plot_check_parameter(cfg, iqr_val, num_frames, X_true, X_med, anchor_1, anchor_2) + plot_check_parameter(cfg, iqr_val, num_frames, X_true, X_med) else: #save numpy arrays the the test/train info: @@ -319,7 +288,7 @@ def traindata_fixed( elif X_z[i,marker] < -cfg['iqr_factor']*iqr_val: X_z[i,marker] = np.nan - X_z[i,:] = interpol(X_z[i,:]) + X_z[i,:] = interpol_all_nans(X_z[i,:]) X_len = len(data.T) pos_temp += X_len @@ -416,7 +385,7 @@ def create_trainset( if not fixed: logger.info("Creating trainset from the vame.egocentrical_alignment() output ") - traindata_aligned(cfg, files, cfg['test_fraction'], cfg['num_features'], cfg['savgol_filter'], check_parameter) + traindata_aligned(cfg, files, cfg['test_fraction'], cfg['savgol_filter'], check_parameter) else: logger.info("Creating trainset from the vame.csv_to_numpy() output ") traindata_fixed(cfg, files, cfg['test_fraction'], cfg['num_features'], cfg['savgol_filter'], check_parameter, pose_ref_index) diff --git a/src/vame/model/evaluate.py b/src/vame/model/evaluate.py index 6a9eb27e..7dce4405 100644 --- a/src/vame/model/evaluate.py +++ b/src/vame/model/evaluate.py @@ -22,18 +22,16 @@ from vame.model.dataloader import SEQUENCE_DATASET from vame.logging.logger import VameLogger + logger_config = VameLogger(__name__) logger = logger_config.logger - - use_gpu = torch.cuda.is_available() if use_gpu: pass else: torch.device("cpu") - def plot_reconstruction( filepath: str, test_loader: Data.DataLoader, diff --git a/src/vame/model/rnn_model.py b/src/vame/model/rnn_model.py index a0947886..2564efb1 100644 --- a/src/vame/model/rnn_model.py +++ b/src/vame/model/rnn_model.py @@ -65,7 +65,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: class Lambda(nn.Module): """Lambda module for computing the latent space parameters.""" - def __init__(self, ZDIMS: int, hidden_size_layer_1: int, hidden_size_layer_2: int, softplus: bool): + def __init__(self, ZDIMS: int, hidden_size_layer_1: int, softplus: bool): """ Initialize the Lambda module. 
@@ -271,7 +271,7 @@ def __init__( self.FUTURE_DECODER = FUTURE_DECODER self.seq_len = int(TEMPORAL_WINDOW / 2) self.encoder = Encoder(NUM_FEATURES, hidden_size_layer_1, hidden_size_layer_2, dropout_encoder) - self.lmbda = Lambda(ZDIMS, hidden_size_layer_1, hidden_size_layer_2, softplus) + self.lmbda = Lambda(ZDIMS, hidden_size_layer_1, softplus) self.decoder = Decoder(self.seq_len,ZDIMS,NUM_FEATURES, hidden_size_rec, dropout_rec) if FUTURE_DECODER: self.decoder_future = Decoder_Future(self.seq_len,ZDIMS,NUM_FEATURES,FUTURE_STEPS, hidden_size_pred, diff --git a/src/vame/model/rnn_vae.py b/src/vame/model/rnn_vae.py index 6dffa49f..4e27a623 100644 --- a/src/vame/model/rnn_vae.py +++ b/src/vame/model/rnn_vae.py @@ -277,9 +277,7 @@ def train( def test( test_loader: Data.DataLoader, - epoch: int, model: nn.Module, - optimizer: torch.optim.Optimizer, BETA: float, kl_weight: float, seq_len: int, @@ -501,7 +499,7 @@ def train_model(config: str, save_logs: bool = False) -> None: FUTURE_STEPS, scheduler, MSE_REC_REDUCTION, MSE_PRED_REDUCTION, KMEANS_LOSS, KMEANS_LAMBDA, TRAIN_BATCH_SIZE, noise) - current_loss, test_loss, test_list = test(test_loader, epoch, model, optimizer, + current_loss, test_loss, test_list = test(test_loader, model, BETA, weight, TEMPORAL_WINDOW, MSE_REC_REDUCTION, KMEANS_LOSS, KMEANS_LAMBDA, FUTURE_DECODER, TEST_BATCH_SIZE) diff --git a/src/vame/schemas/project.py b/src/vame/schemas/project.py index d68d3a12..b594b004 100644 --- a/src/vame/schemas/project.py +++ b/src/vame/schemas/project.py @@ -87,7 +87,4 @@ class ProjectSchema(BaseModel): kl_start: int = Field(default=2, title='KL start') annealtime: int = Field(default=4, title='Annealtime') - # Legacy mode - legacy: bool = Field(default=False, title='Legacy mode') - model_config: ConfigDict = ConfigDict(protected_namespaces=()) \ No newline at end of file diff --git a/src/vame/schemas/states.py b/src/vame/schemas/states.py index d9311492..96cbe8c8 100644 --- a/src/vame/schemas/states.py +++ b/src/vame/schemas/states.py @@ -9,6 +9,7 @@ class StatesEnum(str, Enum): success = 'success' failed = 'failed' running = 'running' + aborted = 'aborted' class GenerativeModelModeEnum(str, Enum): sampling = 'sampling' @@ -56,7 +57,6 @@ class MotifVideosFunctionSchema(BaseStateSchema): class CommunityFunctionSchema(BaseStateSchema): cohort: bool = Field(title='Cohort', default=True) - show_umap: bool = Field(title='Show UMAP', default=False) cut_tree: int | None = Field(title='Cut tree', default=None) @@ -135,5 +135,8 @@ def wrapper(*args, **kwargs): except Exception as e: _save_state(kwargs_model, function_name, state=StatesEnum.failed) raise e + except KeyboardInterrupt as e: + _save_state(kwargs_model, function_name, state=StatesEnum.aborted) + raise e return wrapper return decorator \ No newline at end of file diff --git a/src/vame/util/align_egocentrical.py b/src/vame/util/align_egocentrical.py index da90b430..fbc1bbbb 100644 --- a/src/vame/util/align_egocentrical.py +++ b/src/vame/util/align_egocentrical.py @@ -17,188 +17,16 @@ from pathlib import Path from vame.util.auxiliary import read_config from vame.schemas.states import EgocentricAlignmentFunctionSchema, save_state +from vame.util.data_manipulation import ( + interpol_first_rows_nans, + crop_and_flip, + background +) logger_config = VameLogger(__name__) logger = logger_config.logger -#Returns cropped image using rect tuple -def crop_and_flip( - rect: Tuple, - src: np.ndarray, - points: List[np.ndarray], - ref_index: Tuple[int, int] -) -> Tuple[np.ndarray, List[np.ndarray]]: - 
""" - Crop and flip the image based on the given rectangle and points. - - Args: - rect (Tuple): Rectangle coordinates (center, size, theta). - src (np.ndarray): Source image. - points (List[np.ndarray]): List of points. - ref_index (Tuple[int, int]): Reference indices for alignment. - - Returns: - Tuple[np.ndarray, List[np.ndarray]]: Cropped and flipped image, and shifted points. - """ - #Read out rect structures and convert - center, size, theta = rect - - center, size = tuple(map(int, center)), tuple(map(int, size)) - - # center_lst = list(center) - # center_lst[0] = center[0] - size[0]//2 - # center_lst[1] = center[1] - size[1]//2 - - # center = tuple(center_lst) - - - # center[0] -= size[0]//2 - # center[1] -= size[0]//2 # added this shift to change center to belly 2/28/2024 - - #Get rotation matrix - M = cv.getRotationMatrix2D(center, theta, 1) - - #shift DLC points - x_diff = center[0] - size[0]//2 - y_diff = center[1] - size[1]//2 - - # x_diff = center[0] - # y_diff = center[1] - - dlc_points_shifted = [] - - for i in points: - point=cv.transform(np.array([[[i[0], i[1]]]]),M)[0][0] - - point[0] -= x_diff - point[1] -= y_diff - - dlc_points_shifted.append(point) - - # Perform rotation on src image - dst = cv.warpAffine(src.astype('float32'), M, src.shape[:2]) - out = cv.getRectSubPix(dst, size, center) - - #check if flipped correctly, otherwise flip again - if dlc_points_shifted[ref_index[1]][0] >= dlc_points_shifted[ref_index[0]][0]: - rect = ((size[0]//2,size[0]//2),size,180) #should second value be size[1]? Is this relevant to the flip? 3/5/24 KKL - center, size, theta = rect - center, size = tuple(map(int, center)), tuple(map(int, size)) - - - # center_lst = list(center) - # center_lst[0] = center[0] - size[0]//2 - # center_lst[1] = center[1] - size[1]//2 - # center = tuple(center_lst) - - #Get rotation matrix - M = cv.getRotationMatrix2D(center, theta, 1) - - - #shift DLC points - x_diff = center[0] - size[0]//2 - y_diff = center[1] - size[1]//2 - - # x_diff = center[0] - # y_diff = center[1] - - - points = dlc_points_shifted - dlc_points_shifted = [] - - for i in points: - point=cv.transform(np.array([[[i[0], i[1]]]]),M)[0][0] - - point[0] -= x_diff - point[1] -= y_diff - - dlc_points_shifted.append(point) - - # Perform rotation on src image - dst = cv.warpAffine(out.astype('float32'), M, out.shape[:2]) - out = cv.getRectSubPix(dst, size, center) - - return out, dlc_points_shifted - - -def nan_helper(y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """ - Helper function to identify NaN values in an array. - - Args: - y (np.ndarray): Input array. - - Returns: - Tuple[np.ndarray, np.ndarray]: Boolean mask for NaN values and function to interpolate them. - """ - return np.isnan(y), lambda z: z.nonzero()[0] - - -def interpol(arr: np.ndarray) -> np.ndarray: - """ - Interpolates NaN values in the given array. - - Args: - arr (np.ndarray): Input array. - - Returns: - np.ndarray: Array with interpolated NaN values. - """ - - y = np.transpose(arr) - - nans, x = nan_helper(y[0]) - y[0][nans]= np.interp(x(nans), x(~nans), y[0][~nans]) - nans, x = nan_helper(y[1]) - y[1][nans]= np.interp(x(nans), x(~nans), y[1][~nans]) - - arr = np.transpose(y) - - return arr - -def background(path_to_file: str, filename: str, video_format: str = '.mp4', num_frames: int = 1000) -> np.ndarray: - """ - Compute the background image from a fixed camera. - - Args: - path_to_file (str): Path to the file directory. - filename (str): Name of the video file without the format. 
- video_format (str, optional): Format of the video file. Defaults to '.mp4'. - num_frames (int, optional): Number of frames to use for background computation. Defaults to 1000. - - Returns: - np.ndarray: Background image. - """ - import scipy.ndimage - capture = cv.VideoCapture(os.path.join(path_to_file,'videos',filename+video_format)) - - if not capture.isOpened(): - raise Exception("Unable to open video file: {0}".format(os.path.join(path_to_file,'videos',filename+video_format))) - - frame_count = int(capture.get(cv.CAP_PROP_FRAME_COUNT)) - ret, frame = capture.read() - - height, width, _ = frame.shape - frames = np.zeros((height,width,num_frames)) - - for i in tqdm.tqdm(range(num_frames), disable=not True, desc='Compute background image for video %s' %filename): - rand = np.random.choice(frame_count, replace=False) - capture.set(1,rand) - ret, frame = capture.read() - gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) - frames[...,i] = gray - - logger.info('Finishing up!') - medFrame = np.median(frames,2) - background = scipy.ndimage.median_filter(medFrame, (5,5)) - - # np.save(path_to_file+'videos/'+'background/'+filename+'-background.npy',background) - - capture.release() - return background - - def align_mouse( path_to_file: str, filename: str, @@ -243,7 +71,7 @@ def align_mouse( for i in pose_list: - i = interpol(i) + i = interpol_first_rows_nans(i) if use_video: capture = cv.VideoCapture(os.path.join(path_to_file,'videos',filename+video_format)) @@ -424,7 +252,7 @@ def alignment( if use_video: #compute background - bg = background(path_to_file,filename,video_format) + bg = background(path_to_file,filename,video_format, save_background=False) capture = cv.VideoCapture(os.path.join(path_to_file,'videos',filename+video_format)) if not capture.isOpened(): raise Exception("Unable to open video file: {0}".format(os.path.join(path_to_file,'videos',filename+video_format))) diff --git a/src/vame/util/csv_to_npy.py b/src/vame/util/csv_to_npy.py index 228fe0d2..5a6a1c95 100644 --- a/src/vame/util/csv_to_npy.py +++ b/src/vame/util/csv_to_npy.py @@ -18,49 +18,13 @@ from typing import Tuple from vame.schemas.states import CsvToNumpyFunctionSchema, save_state from vame.logging.logger import VameLogger +from vame.util.data_manipulation import interpol_first_rows_nans logger_config = VameLogger(__name__) logger = logger_config.logger -def nan_helper(y: np.ndarray) -> Tuple: - """ - Identifies indices of NaN values in an array and provides a function to convert them to non-NaN indices. - - Args: - y (np.ndarray): Input array containing NaN values. - - Returns: - Tuple[np.ndarray, Union[np.ndarray, None]]: A tuple containing two elements: - - An array of boolean values indicating the positions of NaN values. - - A lambda function to convert NaN indices to non-NaN indices. - """ - return np.isnan(y), lambda z: z.nonzero()[0] - - - -def interpol(arr: np.ndarray) -> np.ndarray: - """Interpolates all NaN values of a given array. - - Args: - arr (np.ndarray): A numpy array with NaN values. - - Return: - np.ndarray: A numpy array with interpolated NaN values. 
- """ - - y = np.transpose(arr) - - nans, x = nan_helper(y[0]) - y[0][nans]= np.interp(x(nans), x(~nans), y[0][~nans]) - nans, x = nan_helper(y[1]) - y[1][nans]= np.interp(x(nans), x(~nans), y[1][~nans]) - - arr = np.transpose(y) - - return arr - @save_state(model=CsvToNumpyFunctionSchema) def csv_to_numpy(config: str, save_logs=False) -> None: @@ -105,7 +69,7 @@ def csv_to_numpy(config: str, save_logs=False) -> None: # interpolate NaNs for i in pose_list: - i = interpol(i) + i = interpol_first_rows_nans(i) positions = np.concatenate(pose_list, axis=1) final_positions = np.zeros((data_mat.shape[0], int(data_mat.shape[1]/3)*2)) diff --git a/src/vame/util/data_manipulation.py b/src/vame/util/data_manipulation.py new file mode 100644 index 00000000..21d05bbe --- /dev/null +++ b/src/vame/util/data_manipulation.py @@ -0,0 +1,201 @@ +import numpy as np +from typing import List, Tuple +import cv2 as cv +import os +from scipy.ndimage import median_filter +import tqdm +from vame.logging.logger import VameLogger + + +logger_config = VameLogger(__name__) +logger = logger_config.logger + +def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray]: + """Find consecutive sequences in the data array. + + Args: + data (np.ndarray): Input array. + stepsize (int, optional): Step size. Defaults to 1. + + Returns: + List[np.ndarray]: List of consecutive sequences. + """ + data = data[:] + return np.split(data, np.where(np.diff(data) != stepsize)[0]+1) + + +def nan_helper(y: np.ndarray) -> Tuple: + """ + Identifies indices of NaN values in an array and provides a function to convert them to non-NaN indices. + + Args: + y (np.ndarray): Input array containing NaN values. + + Returns: + Tuple[np.ndarray, Union[np.ndarray, None]]: A tuple containing two elements: + - An array of boolean values indicating the positions of NaN values. + - A lambda function to convert NaN indices to non-NaN indices. + """ + return np.isnan(y), lambda z: z.nonzero()[0] + + +def interpol_all_nans(arr: np.ndarray) -> np.ndarray: + """ + Interpolates all NaN values in the given array. + + Args: + arr (np.ndarray): Input array containing NaN values. + + Returns: + np.ndarray: Array with NaN values replaced by interpolated values. + """ + y = np.transpose(arr) + nans, x = nan_helper(y) + y[nans]= np.interp(x(nans), x(~nans), y[~nans]) + arr = np.transpose(y) + return arr + + +def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray: + """ + Interpolates NaN values in the given array. + + Args: + arr (np.ndarray): Input array with NaN values. + + Returns: + np.ndarray: Array with interpolated NaN values. + """ + + y = np.transpose(arr) + + nans, x = nan_helper(y[0]) + y[0][nans]= np.interp(x(nans), x(~nans), y[0][~nans]) + nans, x = nan_helper(y[1]) + y[1][nans]= np.interp(x(nans), x(~nans), y[1][~nans]) + + arr = np.transpose(y) + + return arr + +def crop_and_flip( + rect: Tuple, + src: np.ndarray, + points: List[np.ndarray], + ref_index: Tuple[int, int] +) -> Tuple[np.ndarray, List[np.ndarray]]: + """ + Crop and flip the image based on the given rectangle and points. + + Args: + rect (Tuple): Rectangle coordinates (center, size, theta). + src (np.ndarray): Source image. + points (List[np.ndarray]): List of points. + ref_index (Tuple[int, int]): Reference indices for alignment. + + Returns: + Tuple[np.ndarray, List[np.ndarray]]: Cropped and flipped image, and shifted points. 
+ """ + #Read out rect structures and convert + center, size, theta = rect + + center, size = tuple(map(int, center)), tuple(map(int, size)) + + #Get rotation matrix + M = cv.getRotationMatrix2D(center, theta, 1) + + #shift DLC points + x_diff = center[0] - size[0]//2 + y_diff = center[1] - size[1]//2 + + dlc_points_shifted = [] + + for i in points: + point=cv.transform(np.array([[[i[0], i[1]]]]),M)[0][0] + + point[0] -= x_diff + point[1] -= y_diff + + dlc_points_shifted.append(point) + + # Perform rotation on src image + dst = cv.warpAffine(src.astype('float32'), M, src.shape[:2]) + out = cv.getRectSubPix(dst, size, center) + + #check if flipped correctly, otherwise flip again + if dlc_points_shifted[ref_index[1]][0] >= dlc_points_shifted[ref_index[0]][0]: + rect = ((size[0]//2,size[0]//2),size,180) #should second value be size[1]? Is this relevant to the flip? 3/5/24 KKL + center, size, theta = rect + center, size = tuple(map(int, center)), tuple(map(int, size)) + + #Get rotation matrix + M = cv.getRotationMatrix2D(center, theta, 1) + + #shift DLC points + x_diff = center[0] - size[0]//2 + y_diff = center[1] - size[1]//2 + + points = dlc_points_shifted + dlc_points_shifted = [] + + for i in points: + point=cv.transform(np.array([[[i[0], i[1]]]]),M)[0][0] + + point[0] -= x_diff + point[1] -= y_diff + + dlc_points_shifted.append(point) + + # Perform rotation on src image + dst = cv.warpAffine(out.astype('float32'), M, out.shape[:2]) + out = cv.getRectSubPix(dst, size, center) + + return out, dlc_points_shifted + +def background( + path_to_file: str, + filename: str, + file_format: str = '.mp4', + num_frames: int = 1000, + save_background: bool = True +) -> np.ndarray: + """ + Compute background image from fixed camera. + + Args: + path_to_file (str): Path to the directory containing the video files. + filename (str): Name of the video file. + file_format (str, optional): Format of the video file. Defaults to '.mp4'. + num_frames (int, optional): Number of frames to use for background computation. Defaults to 1000. + + Returns: + np.ndarray: Background image. 
+ """ + + capture = cv.VideoCapture(os.path.join(path_to_file,"videos",filename+file_format)) + + if not capture.isOpened(): + raise Exception("Unable to open video file: {0}".format(os.path.join(path_to_file,"videos",filename+file_format))) + + frame_count = int(capture.get(cv.CAP_PROP_FRAME_COUNT)) + ret, frame = capture.read() + + height, width, _ = frame.shape + frames = np.zeros((height,width,num_frames)) + + for i in tqdm.tqdm(range(num_frames), disable=not True, desc='Compute background image for video %s' %filename): + rand = np.random.choice(frame_count, replace=False) + capture.set(1,rand) + ret, frame = capture.read() + gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) + frames[...,i] = gray + + logger.info('Finishing up!') + medFrame = np.median(frames,2) + background = median_filter(medFrame, (5,5)) + + if save_background: + np.save(os.path.join(path_to_file,"videos",filename+'-background.npy'),background) + + capture.release() + return background \ No newline at end of file diff --git a/src/vame/util/gif_pose_helper.py b/src/vame/util/gif_pose_helper.py index 735abda2..dd682421 100644 --- a/src/vame/util/gif_pose_helper.py +++ b/src/vame/util/gif_pose_helper.py @@ -17,157 +17,16 @@ import numpy as np import pandas as pd from vame.logging.logger import VameLogger +from vame.util.data_manipulation import ( + interpol_first_rows_nans, + crop_and_flip, + background +) + logger_config = VameLogger(__name__) logger = logger_config.logger - -def crop_and_flip(rect: tuple, src: np.ndarray, points: list, ref_index: list) -> tuple: - """ - Crop and flip an image based on a rectangle and reference points. - - Args: - rect (tuple): Tuple containing rectangle information (center, size, angle). - src (np.ndarray): Source image to crop and flip. - points (list): List of points to be aligned. - ref_index (list): Reference indices for alignment. - - Returns: - tuple: Cropped and flipped image, shifted points. 
- """ - #Read out rect structures and convert - center, size, theta = rect - center, size = tuple(map(int, center)), tuple(map(int, size)) - #Get rotation matrix - M = cv.getRotationMatrix2D(center, theta, 1) - - #shift DLC points - x_diff = center[0] - size[0]//2 - y_diff = center[1] - size[1]//2 - - dlc_points_shifted = [] - - for i in points: - point=cv.transform(np.array([[[i[0], i[1]]]]),M)[0][0] - - point[0] -= x_diff - point[1] -= y_diff - - dlc_points_shifted.append(point) - - # Perform rotation on src image - dst = cv.warpAffine(src.astype('float32'), M, src.shape[:2]) - out = cv.getRectSubPix(dst, size, center) - - #check if flipped correctly, otherwise flip again - if dlc_points_shifted[ref_index[1]][0] >= dlc_points_shifted[ref_index[0]][0]: - rect = ((size[0]//2,size[0]//2),size,180) - center, size, theta = rect - center, size = tuple(map(int, center)), tuple(map(int, size)) - #Get rotation matrix - M = cv.getRotationMatrix2D(center, theta, 1) - - - #shift DLC points - x_diff = center[0] - size[0]//2 - y_diff = center[1] - size[1]//2 - - points = dlc_points_shifted - dlc_points_shifted = [] - - for i in points: - point=cv.transform(np.array([[[i[0], i[1]]]]),M)[0][0] - - point[0] -= x_diff - point[1] -= y_diff - - dlc_points_shifted.append(point) - - # Perform rotation on src image - dst = cv.warpAffine(out.astype('float32'), M, out.shape[:2]) - out = cv.getRectSubPix(dst, size, center) - - return out, dlc_points_shifted - - -def background(path_to_file: str, filename: str, file_format: str = '.mp4', num_frames: int = 100) -> np.ndarray: - """ - Compute background image from fixed camera. - - Args: - path_to_file (str): Path to the directory containing the video files. - filename (str): Name of the video file. - file_format (str, optional): Format of the video file. Defaults to '.mp4'. - num_frames (int, optional): Number of frames to use for background computation. Defaults to 1000. - - Returns: - np.ndarray: Background image. - """ - - capture = cv.VideoCapture(os.path.join(path_to_file,"videos",filename+file_format)) - - if not capture.isOpened(): - raise Exception("Unable to open video file: {0}".format(os.path.join(path_to_file,"videos",filename+file_format))) - - frame_count = int(capture.get(cv.CAP_PROP_FRAME_COUNT)) - ret, frame = capture.read() - - height, width, _ = frame.shape - frames = np.zeros((height,width,num_frames)) - - for i in tqdm.tqdm(range(num_frames), disable=not True, desc='Compute background image for video %s' %filename): - rand = np.random.choice(frame_count, replace=False) - capture.set(1,rand) - ret, frame = capture.read() - gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) - frames[...,i] = gray - - logger.info('Finishing up!') - medFrame = np.median(frames,2) - background = scipy.ndimage.median_filter(medFrame, (5,5)) - - np.save(os.path.join(path_to_file,"videos",filename+'-background.npy'),background) - - capture.release() - return background - - -def nan_helper(y: np.ndarray) -> tuple: - """ - Helper function to find indices of NaN values. - - Args: - y (np.ndarray): Input array. - - Returns: - tuple: Indices of NaN values. - """ - return np.isnan(y), lambda z: z.nonzero()[0] - - -def interpol(arr: np.ndarray) -> np.ndarray: - """ - Interpolates NaN values in the given array. - - Args: - arr (np.ndarray): Input array with NaN values. - - Returns: - np.ndarray: Array with interpolated NaN values. 
- """ - - y = np.transpose(arr) - - nans, x = nan_helper(y[0]) - y[0][nans]= np.interp(x(nans), x(~nans), y[0][~nans]) - nans, x = nan_helper(y[1]) - y[1][nans]= np.interp(x(nans), x(~nans), y[1][~nans]) - - arr = np.transpose(y) - - return arr - - def get_animal_frames( cfg: dict, filename: str, @@ -224,7 +83,7 @@ def get_animal_frames( bg = np.load(os.path.join(path_to_file,"videos",filename+'-background.npy')) except Exception: logger.info("Can't find background image... Calculate background image...") - bg = background(path_to_file,filename, file_format) + bg = background(path_to_file,filename, file_format, save_background=True) images = [] points = [] @@ -236,7 +95,7 @@ def get_animal_frames( for i in pose_list: - i = interpol(i) + i = interpol_first_rows_nans(i) capture = cv.VideoCapture(os.path.join(path_to_file,"videos",filename+file_format)) if not capture.isOpened(): diff --git a/src/vame/util/model_util.py b/src/vame/util/model_util.py new file mode 100644 index 00000000..37bc6863 --- /dev/null +++ b/src/vame/util/model_util.py @@ -0,0 +1,70 @@ + +import os +import yaml +import ruamel.yaml +from pathlib import Path +from typing import Tuple +import torch +from vame.logging.logger import VameLogger +from vame.model.rnn_model import RNN_VAE + + +logger_config = VameLogger(__name__) +logger = logger_config.logger + +def load_model(cfg: dict, model_name: str, fixed: bool = True) -> RNN_VAE: + """Load the VAME model. + + Args: + cfg (dict): Configuration dictionary. + model_name (str): Name of the model. + fixed (bool): Fixed or variable length sequences. + + Returns: + RNN_VAE: Loaded VAME model. + """ + # load Model + ZDIMS = cfg['zdims'] + FUTURE_DECODER = cfg['prediction_decoder'] + TEMPORAL_WINDOW = cfg['time_window']*2 + FUTURE_STEPS = cfg['prediction_steps'] + NUM_FEATURES = cfg['num_features'] + + if not fixed: + NUM_FEATURES = NUM_FEATURES - 2 + hidden_size_layer_1 = cfg['hidden_size_layer_1'] + hidden_size_layer_2 = cfg['hidden_size_layer_2'] + hidden_size_rec = cfg['hidden_size_rec'] + hidden_size_pred = cfg['hidden_size_pred'] + dropout_encoder = cfg['dropout_encoder'] + dropout_rec = cfg['dropout_rec'] + dropout_pred = cfg['dropout_pred'] + softplus = cfg['softplus'] + + + logger.info('Loading model... ') + + model = RNN_VAE( + TEMPORAL_WINDOW, + ZDIMS, + NUM_FEATURES, + FUTURE_DECODER, + FUTURE_STEPS, + hidden_size_layer_1, + hidden_size_layer_2, + hidden_size_rec, + hidden_size_pred, + dropout_encoder, + dropout_rec, + dropout_pred, + softplus + ) + if torch.cuda.is_available(): + model = model.cuda() + else: + model = model.cpu() + + model.load_state_dict(torch.load(os.path.join(cfg['project_path'],'model','best_model',model_name+'_'+cfg['Project']+'.pkl'))) + model.eval() + + return model diff --git a/tests/conftest.py b/tests/conftest.py index 125bdebd..f6e5826a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,25 +5,6 @@ import shutil -def pytest_collection_modifyitems(items): - """Modifies test items in place to ensure test modules run in a given order. - We are using this because these are integration tests and we need to run them in a specific order to avoid errors. 
- """ - MODULE_ORDER = [ - "test_initialize_project", - "test_util", - "test_model", - "test_analysis" - ] - module_mapping = {item: item.module.__name__ for item in items} - sorted_items = items.copy() - # Iteratively move tests of each module to the end of the test queue - for module in MODULE_ORDER: - sorted_items = [it for it in sorted_items if module_mapping[it] != module] + [ - it for it in sorted_items if module_mapping[it] == module - ] - items[:] = sorted_items - def init_project( project: str, videos: list, diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 5ca3d850..c2073617 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -40,13 +40,11 @@ def test_motif_videos_files_exists(setup_project_and_train_model): assert len(list(save_base_path.glob("*.mp4"))) > 0 assert len(list(save_base_path.glob("*.mp4"))) <= n_cluster -@pytest.mark.parametrize("show_umap", [True, False]) -def test_community_files_exists(setup_project_and_train_model, show_umap): + +def test_community_files_exists(setup_project_and_train_model): # Check if the files are created vame.community( setup_project_and_train_model['config_path'], - show_umap=show_umap, - save_umap_figure=show_umap, cut_tree=2, cohort=False ) @@ -66,16 +64,12 @@ def test_community_files_exists(setup_project_and_train_model, show_umap): assert community_label_path.exists() assert hierarchy_path.exists() - if show_umap: - umap_save_path = save_base_path / f'{file}_umap.png' - assert umap_save_path.exists() def test_cohort_community_files_exists(setup_project_and_train_model): # Check if the files are created vame.community( setup_project_and_train_model['config_path'], - show_umap=False, cut_tree=2, cohort=True, save_logs=True @@ -145,13 +139,12 @@ def test_gif_frames_files_exists(setup_project_and_evaluate_model, label): with patch("builtins.input", return_value="yes"): vame.pose_segmentation(setup_project_and_evaluate_model["config_path"]) - def mock_background(path_to_file=None, filename=None, file_format=None, num_frames=None): + def mock_background(path_to_file=None, filename=None, file_format=None, num_frames=None, save_background=True): num_frames = 100 - return background(path_to_file, filename, file_format, num_frames) + return background(path_to_file, filename, file_format, num_frames, save_background) vame.community( setup_project_and_evaluate_model["config_path"], - show_umap=False, cut_tree=2, cohort=False, save_logs=False