Skip to content

Commit

Permalink
Merge pull request #69 from EthoML/main
Browse files Browse the repository at this point in the history
docs
  • Loading branch information
vinicvaz authored Jul 29, 2024
2 parents 32148d3 + 5247b82 commit d8b4a52
Show file tree
Hide file tree
Showing 22 changed files with 592 additions and 302 deletions.
14 changes: 7 additions & 7 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Create motif videos to get insights about the fine grained poses\n",
"vame.motif_videos(config, videoType='.mp4')"
"vame.motif_videos(config, videoType='.mp4', parametrization='hmm')"
]
},
{
Expand All @@ -142,7 +142,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Create behavioural hierarchies via community detection\n",
"vame.community(config, cut_tree=2, cohort=False)"
"vame.community(config, cut_tree=2, cohort=True, parametrization='hmm')"
]
},
{
Expand All @@ -153,7 +153,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Create community videos to get insights about behavior on a hierarchical scale\n",
"vame.community_videos(config)"
"vame.community_videos(config, parametrization='hmm')"
]
},
{
Expand All @@ -164,7 +164,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Down projection of latent vectors and visualization via UMAP\n",
"fig = vame.visualization(config, label=None) #options: label: None, \"motif\", \"community\""
"fig = vame.visualization(config, label=None, parametrization='hmm') #options: label: None, \"motif\", \"community\""
]
},
{
Expand All @@ -177,7 +177,7 @@
"# # OPTIONAL: Use the generative model (reconstruction decoder) to sample from\n",
"# # the learned data distribution, reconstruct random real samples or visualize\n",
"# # the cluster center for validation\n",
"vame.generative_model(config, mode=\"centers\") #options: mode: \"sampling\", \"reconstruction\", \"centers\", \"motifs\""
"vame.generative_model(config, mode=\"sampling\", parametrization='hmm') #options: mode: \"sampling\", \"reconstruction\", \"centers\", \"motifs\""
]
},
{
Expand All @@ -192,8 +192,8 @@
"# and have something cool to show around ;)\n",
"# Note: This function is currently very slow. Once the frames are saved you can create a video\n",
"# or gif via e.g. ImageJ or other tools\n",
"vame.gif(config, pose_ref_index=[0,5], subtract_background=True, start=None,\n",
" length=500, max_lag=30, label='community', file_format='.mp4', crop_size=(300,300))\n"
"vame.gif(config, parametrization='hmm', pose_ref_index=[0,5], subtract_background=True, start=None,\n",
" length=30, max_lag=30, label='community', file_format='.mp4', crop_size=(300,300))\n"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vame-py"
version = '0.3.0'
version = '0.5.0'
dynamic = ["dependencies"]
description = "Variational Animal Motion Embedding."
authors = [
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ h5py>=3.11.0
pydantic==2.7.4
imageio==2.34.1
imageio-ffmpeg==0.5.1
pynwb==2.8.1
2 changes: 1 addition & 1 deletion src/vame/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from vame.analysis import visualization
from vame.analysis import generative_model
from vame.analysis import gif
from vame.util.csv_to_npy import csv_to_numpy
from vame.util.csv_to_npy import pose_to_numpy
from vame.util.align_egocentrical import egocentric_alignment
from vame.util import model_util
from vame.util import auxiliary
Expand Down
20 changes: 14 additions & 6 deletions src/vame/analysis/community_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import List, Tuple
from vame.schemas.states import save_state, CommunityFunctionSchema
from vame.logging.logger import VameLogger
from vame.schemas.project import Parametrizations


logger_config = VameLogger(__name__)
Expand Down Expand Up @@ -417,6 +418,7 @@ def get_cohort_community_labels(
@save_state(model=CommunityFunctionSchema)
def community(
config: str,
parametrization: Parametrizations,
cohort: bool = True,
cut_tree: int | None = None,
save_logs: bool = False
Expand All @@ -434,12 +436,13 @@ def community(
try:
config_file = Path(config).resolve()
cfg = read_config(config_file)
parametrizations = cfg['parametrizations']

if save_logs:
log_path = Path(cfg['project_path']) / 'logs' / 'community.log'
logger_config.add_file_handler(log_path)
model_name = cfg['model_name']
n_cluster = cfg['n_cluster']
parametrization = cfg['parametrization']

files = []
if cfg['all_data'] == 'No':
Expand All @@ -464,6 +467,11 @@ def community(
files.append(all_flag)

if cohort:
path_to_file = Path(os.path.join(cfg['project_path'], "results", 'community_cohort', parametrization + '-'+str(n_cluster)))

if not path_to_file.exists():
path_to_file.mkdir(parents=True, exist_ok=True)

labels = get_community_label(cfg, files, model_name, n_cluster, parametrization)
augmented_label, zero_motifs = augment_motif_timeseries(labels, n_cluster)
_, trans_mat_full,_ = get_adjacency_matrix(augmented_label, n_cluster=n_cluster)
Expand All @@ -476,12 +484,12 @@ def community(
# convert communities_all to dtype object numpy array because communities_all is an inhomogeneous list
communities_all = np.array(communities_all, dtype=object)

np.save(os.path.join(cfg['project_path'],"cohort_transition_matrix"+'.npy'),trans_mat_full)
np.save(os.path.join(cfg['project_path'],"cohort_community_label"+'.npy'), community_labels_all)
np.save(os.path.join(cfg['project_path'],"cohort_" + parametrization + "_label"+'.npy'), labels)
np.save(os.path.join(cfg['project_path'],"cohort_community_bag"+'.npy'), communities_all)
np.save(os.path.join(path_to_file,"cohort_transition_matrix"+'.npy'),trans_mat_full)
np.save(os.path.join(path_to_file,"cohort_community_label"+'.npy'), community_labels_all)
np.save(os.path.join(path_to_file,"cohort_" + parametrization + "_label"+'.npy'), labels)
np.save(os.path.join(path_to_file,"cohort_community_bag"+'.npy'), communities_all)

with open(os.path.join(cfg['project_path'],"hierarchy"+".pkl"), "wb") as fp: #Pickling
with open(os.path.join(path_to_file, "hierarchy"+".pkl"), "wb") as fp: #Pickling
pickle.dump(communities_all, fp)

# Work in Progress
Expand Down
10 changes: 7 additions & 3 deletions src/vame/analysis/generative_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from vame.schemas.states import GenerativeModelFunctionSchema, save_state
from vame.util.auxiliary import read_config
from vame.logging.logger import VameLogger
from typing import Dict
from vame.util.model_util import load_model
from vame.schemas.project import Parametrizations


logger_config = VameLogger(__name__)
Expand Down Expand Up @@ -191,15 +193,15 @@ def visualize_cluster_center(cfg: dict, model: torch.nn.Module, cluster_center:


@save_state(model=GenerativeModelFunctionSchema)
def generative_model(config: str, mode: str = "sampling", save_logs: bool = False) -> plt.Figure:
def generative_model(config: str, parametrization: Parametrizations, mode: str = "sampling", save_logs: bool = False) -> Dict[str, plt.Figure]:
"""Generative model.
Args:
config (str): Path to the configuration file.
mode (str, optional): Mode for generating samples. Defaults to "sampling".
Returns:
plt.Figure: Plot of generated samples.
Dict[str, plt.Figure]: Plots of generated samples for each parametrization.
"""
try:
config_file = Path(config).resolve()
Expand All @@ -210,7 +212,7 @@ def generative_model(config: str, mode: str = "sampling", save_logs: bool = Fals
logger.info(f'Running generative model with mode {mode}...')
model_name = cfg['model_name']
n_cluster = cfg['n_cluster']
parametrization = cfg['parametrization']
parametrizations = cfg['parametrizations']

files = []
if cfg['all_data'] == 'No':
Expand Down Expand Up @@ -249,6 +251,8 @@ def generative_model(config: str, mode: str = "sampling", save_logs: bool = Fals
return random_reconstruction_samples(cfg, model, latent_vector)

if mode == "centers":
if parametrization != 'kmeans':
raise ValueError(f"Parametrization {parametrization} not supported for cluster center visualization.")
cluster_center = np.load(os.path.join(path_to_file,'cluster_center_'+file+'.npy'))
return visualize_cluster_center(cfg, model, cluster_center)

Expand Down
10 changes: 7 additions & 3 deletions src/vame/analysis/gif_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vame.util.gif_pose_helper import get_animal_frames
from typing import List, Tuple
from vame.logging.logger import VameLogger
from vame.schemas.project import Parametrizations



Expand Down Expand Up @@ -91,6 +92,7 @@ def create_video(
def gif(
config: str,
pose_ref_index: int,
parametrization: Parametrizations,
subtract_background: bool = True,
start: int | None = None,
length: int = 500,
Expand Down Expand Up @@ -120,7 +122,9 @@ def gif(
cfg = read_config(config_file)
model_name = cfg['model_name']
n_cluster = cfg['n_cluster']
param = cfg['parametrization']

if parametrization not in cfg['parametrizations']:
raise ValueError("Parametrization not found in config")

files = []
if cfg['all_data'] == 'No':
Expand All @@ -146,7 +150,7 @@ def gif(


for file in files:
path_to_file=os.path.join(cfg['project_path'],"results",file,model_name,param+'-'+str(n_cluster),"")
path_to_file=os.path.join(cfg['project_path'],"results",file,model_name, parametrization+'-'+str(n_cluster),"")
if not os.path.exists(os.path.join(path_to_file,"gif_frames")):
os.mkdir(os.path.join(path_to_file,"gif_frames"))

Expand All @@ -173,7 +177,7 @@ def gif(
np.save(os.path.join(path_to_file,"community","umap_embedding_"+file+'.npy'), embed)

if label == "motif":
umap_label = np.load(os.path.join(path_to_file,str(n_cluster)+"_" + param + "_label_"+file+'.npy'))
umap_label = np.load(os.path.join(path_to_file,str(n_cluster)+"_" + parametrization + "_label_"+file+'.npy'))
elif label == "community":
umap_label = np.load(os.path.join(path_to_file,"community","community_label_"+file+'.npy'))
elif label is None:
Expand Down
Loading

0 comments on commit d8b4a52

Please sign in to comment.