Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #66

Merged
merged 23 commits into from
Jul 24, 2024
Merged

Dev #66

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9abb877
using enum parametrization in schema
vinicvaz Jul 9, 2024
1c4195c
pose segmentation using param array
vinicvaz Jul 9, 2024
ba91a84
communiy analysis using param array
vinicvaz Jul 9, 2024
4f334fa
videowriter using param array
vinicvaz Jul 9, 2024
49a5db3
generative function accepting parametrization as arg
vinicvaz Jul 9, 2024
945732c
gif creator accepting parametrization as arg
vinicvaz Jul 9, 2024
97e9f69
umap visualization automatically generating vis for each parametrization
vinicvaz Jul 9, 2024
12056a5
improve tests to use multiple parametrizations
vinicvaz Jul 9, 2024
34bd482
motif videos using parametrization enum
vinicvaz Jul 17, 2024
37b42a8
Merge branch 'main' of https://github.com/EthoML/VAME into feat/run-a…
vinicvaz Jul 17, 2024
f4fddfa
add param argument to functions
vinicvaz Jul 17, 2024
bdb3b62
saving cohort in specific folder
vinicvaz Jul 17, 2024
450b812
Merge pull request #58 from EthoML/feat/run-all-parametrization
vinicvaz Jul 17, 2024
ad5ea1b
save logs pose seg
vinicvaz Jul 19, 2024
7e18876
Merge branch 'dev' of https://github.com/EthoML/VAME into dev
vinicvaz Jul 19, 2024
36bef8c
Fix pose seg cov
vinicvaz Jul 22, 2024
515fcd4
add output file type to videowriter
vinicvaz Jul 22, 2024
dd54c38
add tests to different types
vinicvaz Jul 22, 2024
30cea9e
remove duplicated logic
vinicvaz Jul 22, 2024
daa81ee
keep possible deprecated function commented
vinicvaz Jul 22, 2024
1dd35de
test already initialized project
vinicvaz Jul 22, 2024
48763c3
fix init project from folder
vinicvaz Jul 23, 2024
482d226
fix test init project
vinicvaz Jul 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Create motif videos to get insights about the fine grained poses\n",
"vame.motif_videos(config, videoType='.mp4')"
"vame.motif_videos(config, videoType='.mp4', parametrization='hmm')"
]
},
{
Expand All @@ -142,7 +142,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Create behavioural hierarchies via community detection\n",
"vame.community(config, cut_tree=2, cohort=False)"
"vame.community(config, cut_tree=2, cohort=True, parametrization='hmm')"
]
},
{
Expand All @@ -153,7 +153,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Create community videos to get insights about behavior on a hierarchical scale\n",
"vame.community_videos(config)"
"vame.community_videos(config, parametrization='hmm')"
]
},
{
Expand All @@ -164,7 +164,7 @@
"outputs": [],
"source": [
"# # OPTIONAL: Down projection of latent vectors and visualization via UMAP\n",
"fig = vame.visualization(config, label=None) #options: label: None, \"motif\", \"community\""
"fig = vame.visualization(config, label=None, parametrization='hmm') #options: label: None, \"motif\", \"community\""
]
},
{
Expand All @@ -177,7 +177,7 @@
"# # OPTIONAL: Use the generative model (reconstruction decoder) to sample from\n",
"# # the learned data distribution, reconstruct random real samples or visualize\n",
"# # the cluster center for validation\n",
"vame.generative_model(config, mode=\"centers\") #options: mode: \"sampling\", \"reconstruction\", \"centers\", \"motifs\""
"vame.generative_model(config, mode=\"sampling\", parametrization='hmm') #options: mode: \"sampling\", \"reconstruction\", \"centers\", \"motifs\""
]
},
{
Expand All @@ -192,8 +192,8 @@
"# and have something cool to show around ;)\n",
"# Note: This function is currently very slow. Once the frames are saved you can create a video\n",
"# or gif via e.g. ImageJ or other tools\n",
"vame.gif(config, pose_ref_index=[0,5], subtract_background=True, start=None,\n",
" length=500, max_lag=30, label='community', file_format='.mp4', crop_size=(300,300))\n"
"vame.gif(config, parametrization='hmm', pose_ref_index=[0,5], subtract_background=True, start=None,\n",
" length=30, max_lag=30, label='community', file_format='.mp4', crop_size=(300,300))\n"
]
},
{
Expand Down
20 changes: 14 additions & 6 deletions src/vame/analysis/community_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import List, Tuple
from vame.schemas.states import save_state, CommunityFunctionSchema
from vame.logging.logger import VameLogger
from vame.schemas.project import Parametrizations


logger_config = VameLogger(__name__)
Expand Down Expand Up @@ -417,6 +418,7 @@ def get_cohort_community_labels(
@save_state(model=CommunityFunctionSchema)
def community(
config: str,
parametrization: Parametrizations,
cohort: bool = True,
cut_tree: int | None = None,
save_logs: bool = False
Expand All @@ -434,12 +436,13 @@ def community(
try:
config_file = Path(config).resolve()
cfg = read_config(config_file)
parametrizations = cfg['parametrizations']

if save_logs:
log_path = Path(cfg['project_path']) / 'logs' / 'community.log'
logger_config.add_file_handler(log_path)
model_name = cfg['model_name']
n_cluster = cfg['n_cluster']
parametrization = cfg['parametrization']

files = []
if cfg['all_data'] == 'No':
Expand All @@ -464,6 +467,11 @@ def community(
files.append(all_flag)

if cohort:
path_to_file = Path(os.path.join(cfg['project_path'], "results", 'community_cohort', parametrization + '-'+str(n_cluster)))

if not path_to_file.exists():
path_to_file.mkdir(parents=True, exist_ok=True)

labels = get_community_label(cfg, files, model_name, n_cluster, parametrization)
augmented_label, zero_motifs = augment_motif_timeseries(labels, n_cluster)
_, trans_mat_full,_ = get_adjacency_matrix(augmented_label, n_cluster=n_cluster)
Expand All @@ -476,12 +484,12 @@ def community(
# convert communities_all to dtype object numpy array because communities_all is an inhomogeneous list
communities_all = np.array(communities_all, dtype=object)

np.save(os.path.join(cfg['project_path'],"cohort_transition_matrix"+'.npy'),trans_mat_full)
np.save(os.path.join(cfg['project_path'],"cohort_community_label"+'.npy'), community_labels_all)
np.save(os.path.join(cfg['project_path'],"cohort_" + parametrization + "_label"+'.npy'), labels)
np.save(os.path.join(cfg['project_path'],"cohort_community_bag"+'.npy'), communities_all)
np.save(os.path.join(path_to_file,"cohort_transition_matrix"+'.npy'),trans_mat_full)
np.save(os.path.join(path_to_file,"cohort_community_label"+'.npy'), community_labels_all)
np.save(os.path.join(path_to_file,"cohort_" + parametrization + "_label"+'.npy'), labels)
np.save(os.path.join(path_to_file,"cohort_community_bag"+'.npy'), communities_all)

with open(os.path.join(cfg['project_path'],"hierarchy"+".pkl"), "wb") as fp: #Pickling
with open(os.path.join(path_to_file, "hierarchy"+".pkl"), "wb") as fp: #Pickling
pickle.dump(communities_all, fp)

# Work in Progress
Expand Down
10 changes: 7 additions & 3 deletions src/vame/analysis/generative_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from vame.schemas.states import GenerativeModelFunctionSchema, save_state
from vame.util.auxiliary import read_config
from vame.logging.logger import VameLogger
from typing import Dict
from vame.util.model_util import load_model
from vame.schemas.project import Parametrizations


logger_config = VameLogger(__name__)
Expand Down Expand Up @@ -191,15 +193,15 @@ def visualize_cluster_center(cfg: dict, model: torch.nn.Module, cluster_center:


@save_state(model=GenerativeModelFunctionSchema)
def generative_model(config: str, mode: str = "sampling", save_logs: bool = False) -> plt.Figure:
def generative_model(config: str, parametrization: Parametrizations, mode: str = "sampling", save_logs: bool = False) -> Dict[str, plt.Figure]:
"""Generative model.

Args:
config (str): Path to the configuration file.
mode (str, optional): Mode for generating samples. Defaults to "sampling".

Returns:
plt.Figure: Plot of generated samples.
Dict[str, plt.Figure]: Plots of generated samples for each parametrization.
"""
try:
config_file = Path(config).resolve()
Expand All @@ -210,7 +212,7 @@ def generative_model(config: str, mode: str = "sampling", save_logs: bool = Fals
logger.info(f'Running generative model with mode {mode}...')
model_name = cfg['model_name']
n_cluster = cfg['n_cluster']
parametrization = cfg['parametrization']
parametrizations = cfg['parametrizations']

files = []
if cfg['all_data'] == 'No':
Expand Down Expand Up @@ -249,6 +251,8 @@ def generative_model(config: str, mode: str = "sampling", save_logs: bool = Fals
return random_reconstruction_samples(cfg, model, latent_vector)

if mode == "centers":
if parametrization != 'kmeans':
raise ValueError(f"Parametrization {parametrization} not supported for cluster center visualization.")
cluster_center = np.load(os.path.join(path_to_file,'cluster_center_'+file+'.npy'))
return visualize_cluster_center(cfg, model, cluster_center)

Expand Down
10 changes: 7 additions & 3 deletions src/vame/analysis/gif_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vame.util.gif_pose_helper import get_animal_frames
from typing import List, Tuple
from vame.logging.logger import VameLogger
from vame.schemas.project import Parametrizations



Expand Down Expand Up @@ -91,6 +92,7 @@
def gif(
config: str,
pose_ref_index: int,
parametrization: Parametrizations,
subtract_background: bool = True,
start: int | None = None,
length: int = 500,
Expand Down Expand Up @@ -120,7 +122,9 @@
cfg = read_config(config_file)
model_name = cfg['model_name']
n_cluster = cfg['n_cluster']
param = cfg['parametrization']

if parametrization not in cfg['parametrizations']:
raise ValueError("Parametrization not found in config")

Check warning on line 127 in src/vame/analysis/gif_creator.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/gif_creator.py#L127

Added line #L127 was not covered by tests

files = []
if cfg['all_data'] == 'No':
Expand All @@ -146,7 +150,7 @@


for file in files:
path_to_file=os.path.join(cfg['project_path'],"results",file,model_name,param+'-'+str(n_cluster),"")
path_to_file=os.path.join(cfg['project_path'],"results",file,model_name, parametrization+'-'+str(n_cluster),"")
if not os.path.exists(os.path.join(path_to_file,"gif_frames")):
os.mkdir(os.path.join(path_to_file,"gif_frames"))

Expand All @@ -173,7 +177,7 @@
np.save(os.path.join(path_to_file,"community","umap_embedding_"+file+'.npy'), embed)

if label == "motif":
umap_label = np.load(os.path.join(path_to_file,str(n_cluster)+"_" + param + "_label_"+file+'.npy'))
umap_label = np.load(os.path.join(path_to_file,str(n_cluster)+"_" + parametrization + "_label_"+file+'.npy'))
elif label == "community":
umap_label = np.load(os.path.join(path_to_file,"community","community_label_"+file+'.npy'))
elif label is None:
Expand Down
Loading