Skip to content

Commit

Permalink
Merge pull request #61 from EthoML/main
Browse files Browse the repository at this point in the history
Merge main
  • Loading branch information
vinicvaz authored Jul 15, 2024
2 parents 6f00c7f + 1cee939 commit 51ecf3e
Show file tree
Hide file tree
Showing 25 changed files with 407 additions and 781 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
<a href="https://codecov.io/gh/EthoML/VAME" >
<img src="https://codecov.io/gh/EthoML/VAME/graph/badge.svg?token=J1CUXB4N0E"/>
</a>
<a href="https://pypi.org/project/vame-py">
<img src="https://img.shields.io/pypi/v/vame-py?color=%231BA331&label=PyPI&logo=python&logoColor=%23F7F991%20">
</a>
</p>

🌟 Welcome to EthoML/VAME (Variational Animal Motion Encoding), an open-source machine learning tool for behavioral segmentation and analyses.
Expand Down
2 changes: 1 addition & 1 deletion examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Create behavioural hierarchies via community detection\n",
"vame.community(config, show_umap=False, cut_tree=2, cohort=True)"
"vame.community(config, cut_tree=2, cohort=False)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
# vame.motif_videos(config, videoType='.mp4')

# # OPTIONAL: Create behavioural hierarchies via community detection
# vame.community(config, show_umap=False, cut_tree=2)
# vame.community(config, cut_tree=2)

# # OPTIONAL: Create community videos to get insights about behavior on a hierarchical scale
# vame.community_videos(config)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vame-py"
version = '0.2.0'
version = '0.3.0'
dynamic = ["dependencies"]
description = "Variational Animal Motion Embedding."
authors = [
Expand Down
1 change: 1 addition & 0 deletions src/vame/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@
from vame.analysis import gif
from vame.util.csv_to_npy import csv_to_numpy
from vame.util.align_egocentrical import egocentric_alignment
from vame.util import model_util
from vame.util import auxiliary

2 changes: 1 addition & 1 deletion src/vame/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vame.analysis.pose_segmentation import pose_segmentation
from vame.analysis.videowriter import motif_videos, community_videos
from vame.analysis.community_analysis import community
from vame.analysis.umap_visualization import visualization
from vame.analysis.umap import visualization
from vame.analysis.generative_functions import generative_model
from vame.analysis.gif_creator import gif

98 changes: 2 additions & 96 deletions src/vame/analysis/community_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
"""

import os
import umap
import scipy
import pickle
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from vame.util.auxiliary import read_config
from vame.analysis.tree_hierarchy import graph_to_tree, draw_tree, traverse_tree_cutline
from vame.util.data_manipulation import consecutive
from typing import List, Tuple
from vame.schemas.states import save_state, CommunityFunctionSchema
from vame.logging.logger import VameLogger
Expand Down Expand Up @@ -87,21 +87,6 @@ def get_transition_matrix(adjacency_matrix: np.ndarray, threshold: float = 0.0)
return transition_matrix


def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray]:
    """Split an array into runs of consecutive elements.

    A new run starts wherever the difference between two neighbouring
    elements is not equal to ``stepsize``.

    Args:
        data (np.ndarray): Input array.
        stepsize (int, optional): Expected difference between neighbouring
            elements of the same run. Defaults to 1.

    Returns:
        List[np.ndarray]: The pieces of ``data``, one array per run.
    """
    values = data[:]
    is_break = np.diff(values) != stepsize
    # A break between positions i and i + 1 means the next run starts at i + 1.
    split_points = np.nonzero(is_break)[0] + 1
    return np.split(values, split_points)

# New find_zero_labels 8/8/2022 KL


def find_zero_labels(motif_usage: Tuple[np.ndarray, np.ndarray], n_cluster: int) -> np.ndarray:
"""Find zero labels in motif usage and fill them.
Expand Down Expand Up @@ -311,7 +296,6 @@ def create_community_bag(
return communities_all, trees

def create_cohort_community_bag(
files: List[str],
labels: List[np.ndarray],
trans_mat_full: np.ndarray,
cut_tree: int,
Expand All @@ -321,7 +305,6 @@ def create_cohort_community_bag(
(markov chain to tree -> community detection)
Args:
files (List[str]): List of files paths (deprecated).
labels (List[np.ndarray]): List of label arrays.
trans_mat_full (np.ndarray): Full transition matrix.
cut_tree (int): Cut line for tree.
Expand Down Expand Up @@ -431,84 +414,18 @@ def get_cohort_community_labels(
return community_labels_all


def umap_embedding(cfg: dict, file: str, model_name: str, n_cluster: int, parametrization: str) -> np.ndarray:
    """Compute a 2-component UMAP embedding of a file's saved latent vectors.

    Args:
        cfg (dict): Configuration parameters. Reads 'min_dist', 'n_neighbors',
            'random_state', 'project_path' and 'num_points'.
        file (str): File name, used to locate the results folder and the
            ``latent_vector_<file>.npy`` array inside it.
        model_name (str): Model name.
        n_cluster (int): Number of clusters.
        parametrization (str): Parametrization name.

    Returns:
        np.ndarray: UMAP embedding of at most ``cfg['num_points']`` latent vectors.
    """
    reducer = umap.UMAP(
        n_components=2,
        min_dist=cfg['min_dist'],
        n_neighbors=cfg['n_neighbors'],
        random_state=cfg['random_state'],
    )

    logger.info("UMAP calculation for file %s" % file)

    segmentation_dir = parametrization + '-' + str(n_cluster)
    results_dir = os.path.join(cfg['project_path'], "results", file, model_name, segmentation_dir, "")
    latent_file = 'latent_vector_' + file + '.npy'
    latent_vector = np.load(os.path.join(results_dir, latent_file))

    # Never ask for more points than the latent array actually holds.
    num_points = min(cfg['num_points'], latent_vector.shape[0])
    logger.info("Embedding %d data points.." % num_points)

    return reducer.fit_transform(latent_vector[:num_points, :])


def umap_vis(cfg: dict, file: str, embed: np.ndarray, community_labels_all: np.ndarray, save_path: str | None = None) -> None:
    """Create a matplotlib scatter plot of a UMAP embedding coloured by community label.

    Args:
        cfg (dict): Configuration parameters. ``cfg['num_points']`` caps the
            number of labels used for colouring.
        file (str): File name. Not used by the plot; kept so existing callers
            keep working.
        embed (np.ndarray): UMAP embedding; columns 0 and 1 are plotted.
        community_labels_all (np.ndarray): Community labels.
        save_path (str | None, optional): Path to save the plot. If None the
            plot is shown instead of saved. Defaults to None.

    Returns:
        None
    """
    community_labels_all = np.asarray(community_labels_all)
    num_points = min(cfg['num_points'], community_labels_all.shape[0])
    logger.info("Embedding %d data points.." %num_points)

    num = np.unique(community_labels_all)

    # plt.figure(1) hands back the existing figure 1 when it is still open,
    # so clear it first; otherwise the scatter points and colorbars of
    # earlier calls pile up on the same canvas.
    fig = plt.figure(1)
    fig.clf()
    plt.scatter(embed[:,0], embed[:,1], c=community_labels_all[:num_points], cmap='Spectral', s=2, alpha=1)
    # One colorbar bin per integer label from 0 to the largest label present.
    plt.colorbar(boundaries=np.arange(np.max(num)+2)-0.5).set_ticks(np.arange(np.max(num)+1))
    plt.gca().set_aspect('equal', 'datalim')
    plt.grid(False)

    if save_path is not None:
        plt.savefig(save_path)
        # Release the figure once it is on disk so it is not kept in memory.
        plt.close(fig)
        return
    plt.show()

@save_state(model=CommunityFunctionSchema)
def community(
config: str,
cohort: bool = True,
show_umap: bool = False,
cut_tree: int | None = None,
save_umap_figure: bool = True,
save_logs: bool = False
) -> None:
"""Perform community analysis.
Args:
config (str): Path to the configuration file.
cohort (bool, optional): Flag indicating cohort analysis. Defaults to True.
show_umap (bool, optional): Flag indicating weather to show UMAP visualization. Defaults to False.
cut_tree (int, optional): Cut line for tree. Defaults to None.
Returns:
Expand Down Expand Up @@ -551,7 +468,7 @@ def community(
augmented_label, zero_motifs = augment_motif_timeseries(labels, n_cluster)
_, trans_mat_full,_ = get_adjacency_matrix(augmented_label, n_cluster=n_cluster)
_, usage_full = np.unique(augmented_label, return_counts=True)
communities_all, trees = create_cohort_community_bag(files, labels, trans_mat_full, cut_tree, n_cluster)
communities_all, trees = create_cohort_community_bag(labels, trans_mat_full, cut_tree, n_cluster)

community_labels_all = get_cohort_community_labels(files, labels, communities_all)
# community_bag = traverse_tree_cutline(trees, cutline=cut_tree)
Expand All @@ -567,11 +484,6 @@ def community(
with open(os.path.join(cfg['project_path'],"hierarchy"+".pkl"), "wb") as fp: #Pickling
pickle.dump(communities_all, fp)

if show_umap:
embed = umap_embedding(cfg, files, model_name, n_cluster, parametrization)
# TODO fix umap vis for cohort and add save path
umap_vis(cfg, files, embed, community_labels_all)

# Work in Progress
elif not cohort:
labels = get_labels(cfg, files, model_name, n_cluster, parametrization)
Expand All @@ -590,12 +502,6 @@ def community(
with open(os.path.join(path_to_file,"community","hierarchy"+file+".pkl"), "wb") as fp: #Pickling
pickle.dump(communities_all[idx], fp)

if show_umap:
embed = umap_embedding(cfg, file, model_name, n_cluster, parametrization)
umap_save_path = None
if save_umap_figure:
umap_save_path = os.path.join(path_to_file, "community", file + "_umap.png")
umap_vis(cfg, files, embed, community_labels_all[idx], save_path=umap_save_path)
except Exception as e:
logger.exception(f"Error in community_analysis: {e}")
raise e
Expand Down
47 changes: 3 additions & 44 deletions src/vame/analysis/generative_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from sklearn.mixture import GaussianMixture
from vame.schemas.states import GenerativeModelFunctionSchema, save_state
from vame.util.auxiliary import read_config
from vame.model.rnn_model import RNN_VAE
from vame.logging.logger import VameLogger
from vame.util.model_util import load_model


logger_config = VameLogger(__name__)
logger = logger_config.logger
Expand Down Expand Up @@ -189,48 +190,6 @@ def visualize_cluster_center(cfg: dict, model: torch.nn.Module, cluster_center:
return fig


def load_model(cfg: dict, model_name: str, fixed: bool = False) -> torch.nn.Module:
    """Load the best saved PyTorch model and put it in evaluation mode.

    The weights are read from
    ``<project_path>/model/best_model/<model_name>_<Project>.pkl``.

    Args:
        cfg (dict): Configuration dictionary.
        model_name (str): Name of the model.
        fixed (bool, optional): When False, two features are subtracted from
            ``cfg['num_features']`` before building the network. Defaults to
            False, which is what this function always did.
            NOTE(review): presumably the fixed vs. variable-length /
            egocentric-alignment switch - confirm against the training code.

    Returns:
        torch.nn.Module: Loaded PyTorch model, on the GPU when CUDA is
        available and on the CPU otherwise.
    """
    num_features = cfg['num_features']
    if not fixed:
        num_features = num_features - 2

    logger.info('Loading model... ')

    model = RNN_VAE(
        cfg['time_window']*2,
        cfg['zdims'],
        num_features,
        cfg['prediction_decoder'],
        cfg['prediction_steps'],
        cfg['hidden_size_layer_1'],
        cfg['hidden_size_layer_2'],
        cfg['hidden_size_rec'],
        cfg['hidden_size_pred'],
        cfg['dropout_encoder'],
        cfg['dropout_rec'],
        cfg['dropout_pred'],
        cfg['softplus'],
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    checkpoint_path = os.path.join(cfg['project_path'],'model','best_model',model_name+'_'+cfg['Project']+'.pkl')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine; without it torch.load tries to restore the tensors to CUDA.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    return model

@save_state(model=GenerativeModelFunctionSchema)
def generative_model(config: str, mode: str = "sampling", save_logs: bool = False) -> plt.Figure:
"""Generative model.
Expand Down Expand Up @@ -276,7 +235,7 @@ def generative_model(config: str, mode: str = "sampling", save_logs: bool = Fals
files.append(all_flag)


model = load_model(cfg, model_name)
model = load_model(cfg, model_name, fixed=False)

for file in files:
path_to_file=os.path.join(cfg['project_path'],"results",file,model_name, parametrization + '-' +str(n_cluster),"")
Expand Down
67 changes: 2 additions & 65 deletions src/vame/analysis/pose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,69 +16,20 @@
import numpy as np
from pathlib import Path
from typing import List, Tuple

from vame.util.data_manipulation import consecutive
from hmmlearn import hmm
from sklearn.cluster import KMeans
from vame.schemas.states import save_state, PoseSegmentationFunctionSchema
from vame.logging.logger import VameLogger, TqdmToLogger
from vame.util.auxiliary import read_config
from vame.model.rnn_model import RNN_VAE
from vame.util.model_util import load_model


logger_config = VameLogger(__name__)
logger = logger_config.logger


def load_model(cfg: dict, model_name: str, fixed: bool) -> RNN_VAE:
    """Load the best saved VAME model and put it in evaluation mode.

    The weights are read from
    ``<project_path>/model/best_model/<model_name>_<Project>.pkl``.

    Args:
        cfg (dict): Configuration dictionary.
        model_name (str): Name of the model.
        fixed (bool): Fixed or variable length sequences. When False, two
            features are subtracted from ``cfg['num_features']`` before
            building the network.

    Returns:
        RNN_VAE: Loaded VAME model, on the GPU when CUDA is available and on
        the CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_features = cfg['num_features']
    if not fixed:
        num_features = num_features - 2

    # Build the network once and move it to the selected device, instead of
    # duplicating the constructor call in a GPU and a CPU branch.
    model = RNN_VAE(
        cfg['time_window']*2,
        cfg['zdims'],
        num_features,
        cfg['prediction_decoder'],
        cfg['prediction_steps'],
        cfg['hidden_size_layer_1'],
        cfg['hidden_size_layer_2'],
        cfg['hidden_size_rec'],
        cfg['hidden_size_pred'],
        cfg['dropout_encoder'],
        cfg['dropout_rec'],
        cfg['dropout_pred'],
        cfg['softplus'],
    ).to(device)

    checkpoint_path = os.path.join(cfg['project_path'],'model','best_model',model_name+'_'+cfg['Project']+'.pkl')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine; without it torch.load tries to restore the tensors to CUDA.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    return model


def embedd_latent_vectors(cfg: dict, files: List[str], model: RNN_VAE, fixed: bool, tqdm_stream: TqdmToLogger | None) -> List[np.ndarray]:
"""Embed latent vectors for the given files using the VAME model.
Expand Down Expand Up @@ -128,20 +79,6 @@ def embedd_latent_vectors(cfg: dict, files: List[str], model: RNN_VAE, fixed: bo
return latent_vector_files


def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray]:
    """Find consecutive sequences in the data array.

    The array is cut between every pair of neighbouring elements whose
    difference is not equal to ``stepsize``.

    Args:
        data (np.ndarray): Input array.
        stepsize (int, optional): Step size. Defaults to 1.

    Returns:
        List[np.ndarray]: List of consecutive sequences.
    """
    sequence = data[:]
    steps = np.diff(sequence)
    # Index of the first element of every new sequence.
    boundaries = np.nonzero(steps != stepsize)[0] + 1
    return np.split(sequence, boundaries)


def get_motif_usage(label: np.ndarray) -> np.ndarray:
"""Compute motif usage from the label array.
Expand Down
Loading

0 comments on commit 51ecf3e

Please sign in to comment.