Commit
Merge pull request #52 from EthoML/main
Update docs
vinicvaz authored Jul 1, 2024
2 parents cca3f52 + 619d8bd commit bfb1033
Showing 32 changed files with 1,753 additions and 1,731 deletions.
1 change: 1 addition & 0 deletions .github/workflows/testing.yaml
@@ -5,6 +5,7 @@ on:
branches:
- main
- dev
- fix/conditional-cuda-gm #todo remove


jobs:
85 changes: 12 additions & 73 deletions examples/demo.ipynb

Large diffs are not rendered by default.

5 changes: 0 additions & 5 deletions examples/demo.py
@@ -25,11 +25,6 @@
# After the initial creation of your project you can always access the config.yaml file
# via specifying the path to your project

# As our config.yaml is sometimes still changing a little due to updates, we have here a small function
# to update your config.yaml to the current state. Be aware that this will overwrite your current config.yaml
# and make sure to back up your version if you did parameter changes!
vame.update_config(config, force_update=False)

# Step 1.2: Align your behavior videos egocentric and create training dataset
# pose_ref_index: list of reference coordinate indices for alignment
# Example: 0: snout, 1: forehand_left, 2: forehand_right, 3: hindleft, 4: hindright, 5: tail
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vame-py"
version = '0.1.0'
version = '0.2.0'
dynamic = ["dependencies"]
description = "Variational Animal Motion Embedding."
authors = [
7 changes: 5 additions & 2 deletions requirements.txt
@@ -1,6 +1,6 @@
PyYAML>=6.0.0
ruamel.yaml>=0.18.0
numpy>=1.26.0
numpy==1.26.4
pandas>=2.2.0
scipy>=1.13.0
matplotlib>=3.9.0
@@ -9,4 +9,7 @@ tqdm>=4.66.0
hmmlearn>=0.3.0
opencv-python-headless>=4.9.0.0
umap-learn>=0.5.0
h5py>=3.11.0
h5py>=3.11.0
pydantic==2.7.4
imageio==2.34.1
imageio-ffmpeg==0.5.1
1 change: 0 additions & 1 deletion src/vame/__init__.py
@@ -25,5 +25,4 @@
from vame.util.csv_to_npy import csv_to_numpy
from vame.util.align_egocentrical import egocentric_alignment
from vame.util import auxiliary
from vame.util.auxiliary import update_config

238 changes: 120 additions & 118 deletions src/vame/analysis/community_analysis.py

Large diffs are not rendered by default.

174 changes: 104 additions & 70 deletions src/vame/analysis/generative_functions.py
@@ -14,9 +14,13 @@
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture

from vame.schemas.states import GenerativeModelFunctionSchema, save_state
from vame.util.auxiliary import read_config
from vame.model.rnn_model import RNN_VAE
from vame.logging.logger import VameLogger

logger_config = VameLogger(__name__)
logger = logger_config.logger


def random_generative_samples_motif(
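Editor's note: the new VameLogger replaces the bare print calls in this file, and later in this diff generative_model pairs the module-level logger with a per-run file handler that is attached before the work starts and removed in a finally block. A minimal sketch of that lifecycle, assuming only the VameLogger methods visible in this commit (logger, add_file_handler, remove_file_handler); the wrapper function itself is hypothetical:

from pathlib import Path

from vame.logging.logger import VameLogger

logger_config = VameLogger(__name__)
logger = logger_config.logger


def run_step_with_logging(project_path: str, save_logs: bool = False) -> None:
    # Hypothetical wrapper illustrating the logging lifecycle generative_model adopts below.
    try:
        if save_logs:
            logs_path = Path(project_path) / "logs" / "generative_model.log"
            logger_config.add_file_handler(logs_path)
        logger.info("Running generative model...")
        # ... the actual sampling / reconstruction work would go here ...
    except Exception as e:
        logger.exception(str(e))
        raise
    finally:
        # The handler is always detached so later calls do not keep appending to this file.
        logger_config.remove_file_handler()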
@@ -38,6 +42,7 @@ def random_generative_samples_motif(
Returns:
None: Plot of generated samples.
"""
logger.info('Generate random generative samples for motifs...')
time_window = cfg['time_window']
for j in range(n_cluster):

@@ -49,19 +54,24 @@ def random_generative_samples_motif(
density_sample = gm.sample(10)

# generate image via model decoder
tensor_sample = torch.from_numpy(density_sample[0]).type('torch.FloatTensor').cuda()
tensor_sample = torch.from_numpy(density_sample[0]).type('torch.FloatTensor')
if torch.cuda.is_available():
tensor_sample = tensor_sample.cuda()
else:
tensor_sample = tensor_sample.cpu()

decoder_inputs = tensor_sample.unsqueeze(2).repeat(1, 1, time_window)
decoder_inputs = decoder_inputs.permute(0,2,1)

image_sample = model.decoder(decoder_inputs, tensor_sample)
recon_sample = image_sample.cpu().detach().numpy()


fig, axs = plt.subplots(2,5)
for i in range(5):
axs[0,i].plot(recon_sample[i,...])
axs[1,i].plot(recon_sample[i+5,...])
plt.suptitle('Generated samples for motif '+str(j))
return fig
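Editor's note: this hunk shows the core change of the commit, where the unconditional .cuda() call (old line shown first) is replaced by a guard on torch.cuda.is_available(), so sample generation also runs on CPU-only machines. A standalone sketch of the same pattern; the helper name, tensor shapes, and the explicit torch.device variant are illustrative assumptions, not code from this repository:

import numpy as np
import torch


def to_available_device(array: np.ndarray) -> torch.Tensor:
    # Mirrors the conditional-CUDA pattern introduced in this diff:
    # the tensor is moved to the GPU only when one is actually available.
    tensor = torch.from_numpy(array).float()
    if torch.cuda.is_available():
        return tensor.cuda()
    return tensor.cpu()


# Equivalent single-device form that many PyTorch codebases use instead:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor = torch.from_numpy(np.random.rand(10, 30)).float().to(device)  # hypothetical latent samples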

def random_generative_samples(cfg: dict, model: torch.nn.Module, latent_vector: np.ndarray) -> None:
"""Generate random generative samples.
@@ -74,6 +84,7 @@ def random_generative_samples(cfg: dict, model: torch.nn.Module, latent_vector:
Returns:
None
"""
logger.info('Generate random generative samples...')
# Latent sampling and generative model
time_window = cfg['time_window']
gm = GaussianMixture(n_components=10).fit(latent_vector)
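Editor's note on the sampling step: scikit-learn's GaussianMixture.sample returns a tuple of (samples, component_labels), which is why density_sample[0] is indexed before the draws are handed to the decoder. A small sketch with dimensions chosen only for illustration:

import numpy as np
from sklearn.mixture import GaussianMixture

latent_vector = np.random.randn(1000, 30)            # hypothetical latent vectors (frames x zdims)
gm = GaussianMixture(n_components=10).fit(latent_vector)

samples, component_labels = gm.sample(10)            # samples: (10, 30), labels: (10,)
# The code in this diff keeps only the first element: density_sample[0] corresponds to `samples` here.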
@@ -82,7 +93,12 @@ def random_generative_samples(cfg: dict, model: torch.nn.Module, latent_vector:
density_sample = gm.sample(10)

# generate image via model decoder
tensor_sample = torch.from_numpy(density_sample[0]).type('torch.FloatTensor').cuda()
tensor_sample = torch.from_numpy(density_sample[0]).type('torch.FloatTensor')
if torch.cuda.is_available():
tensor_sample = tensor_sample.cuda()
else:
tensor_sample = tensor_sample.cpu()

decoder_inputs = tensor_sample.unsqueeze(2).repeat(1, 1, time_window)
decoder_inputs = decoder_inputs.permute(0,2,1)

@@ -94,6 +110,7 @@ def random_generative_samples(cfg: dict, model: torch.nn.Module, latent_vector:
axs[0,i].plot(recon_sample[i,...])
axs[1,i].plot(recon_sample[i+5,...])
plt.suptitle('Generated samples')
return fig


def random_reconstruction_samples(cfg: dict, model: torch.nn.Module, latent_vector: np.ndarray) -> None:
@@ -107,11 +124,17 @@ def random_reconstruction_samples(cfg: dict, model: torch.nn.Module, latent_vect
Returns:
None
"""
logger.info('Generate random reconstruction samples...')
# random samples for reconstruction
time_window = cfg['time_window']

rnd = np.random.choice(latent_vector.shape[0], 10)
tensor_sample = torch.from_numpy(latent_vector[rnd]).type('torch.FloatTensor').cuda()
tensor_sample = torch.from_numpy(latent_vector[rnd]).type('torch.FloatTensor')
if torch.cuda.is_available():
tensor_sample = tensor_sample.cuda()
else:
tensor_sample = tensor_sample.cpu()

decoder_inputs = tensor_sample.unsqueeze(2).repeat(1, 1, time_window)
decoder_inputs = decoder_inputs.permute(0,2,1)

@@ -123,6 +146,7 @@ def random_reconstruction_samples(cfg: dict, model: torch.nn.Module, latent_vect
axs[0,i].plot(recon_sample[i,...])
axs[1,i].plot(recon_sample[i+5,...])
plt.suptitle('Reconstructed samples')
return fig


def visualize_cluster_center(cfg: dict, model: torch.nn.Module, cluster_center: np.ndarray) -> None:
@@ -137,10 +161,15 @@ def visualize_cluster_center(cfg: dict, model: torch.nn.Module, cluster_center:
None
"""
#Cluster Center
logger.info('Visualize cluster center...')
time_window = cfg['time_window']
animal_centers = cluster_center

tensor_sample = torch.from_numpy(animal_centers).type('torch.FloatTensor').cuda()
tensor_sample = torch.from_numpy(animal_centers).type('torch.FloatTensor')
if torch.cuda.is_available():
tensor_sample = tensor_sample.cuda()
else:
tensor_sample = tensor_sample.cpu()
decoder_inputs = tensor_sample.unsqueeze(2).repeat(1, 1, time_window)
decoder_inputs = decoder_inputs.permute(0,2,1)

@@ -157,6 +186,7 @@ def visualize_cluster_center(cfg: dict, model: torch.nn.Module, cluster_center:
axs[k,i].plot(recon_sample[idx,...])
axs[k,i].set_title("Cluster %d" %idx)
idx +=1
return fig


def load_model(cfg: dict, model_name: str) -> torch.nn.Module:
@@ -186,85 +216,89 @@ def load_model(cfg: dict, model_name: str) -> torch.nn.Module:
dropout_pred = cfg['dropout_pred']
softplus = cfg['softplus']

print('Load model... ')
logger.info('Loading model... ')

model = RNN_VAE(TEMPORAL_WINDOW,ZDIMS,NUM_FEATURES,FUTURE_DECODER,FUTURE_STEPS, hidden_size_layer_1,
hidden_size_layer_2, hidden_size_rec, hidden_size_pred, dropout_encoder,
dropout_rec, dropout_pred, softplus).cuda()
dropout_rec, dropout_pred, softplus)
if torch.cuda.is_available():
model = model.cuda()
else:
model = model.cpu()

model.load_state_dict(torch.load(os.path.join(cfg['project_path'],'model','best_model',model_name+'_'+cfg['Project']+'.pkl')))
model.eval()

return model
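Editor's note: load_model is now built device-conditionally, but torch.load is still called without map_location, so a checkpoint saved on a GPU machine may fail to load on a CPU-only one. A hedged sketch of the usual workaround; this is not code from the commit, and the helper name and path argument are illustrative:

import torch


def load_state_dict_any_device(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    # map_location remaps CUDA-saved tensors onto whatever device is available,
    # so CPU-only machines can still load GPU-trained checkpoints.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model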


def generative_model(config: str, mode: str = "sampling") -> None:
@save_state(model=GenerativeModelFunctionSchema)
def generative_model(config: str, mode: str = "sampling", save_logs: bool = False) -> plt.Figure:
"""Generative model.
Args:
config (str): Path to the configuration file.
mode (str, optional): Mode for generating samples. Defaults to "sampling".
Returns:
None
plt.Figure: Plot of generated samples.
"""
config_file = Path(config).resolve()
cfg = read_config(config_file)
model_name = cfg['model_name']
n_cluster = cfg['n_cluster']
parametrization = cfg['parametrization']

files = []
if cfg['all_data'] == 'No':
all_flag = input("Do you want to write motif videos for your entire dataset? \n"
"If you only want to use a specific dataset type filename: \n"
"yes/no/filename ")
else:
all_flag = 'yes'

if all_flag == 'yes' or all_flag == 'Yes':
for file in cfg['video_sets']:
files.append(file)

elif all_flag == 'no' or all_flag == 'No':
for file in cfg['video_sets']:
use_file = input("Do you want to quantify " + file + "? yes/no: ")
if use_file == 'yes':
try:
config_file = Path(config).resolve()
cfg = read_config(config_file)
if save_logs:
logs_path = Path(cfg['project_path']) / "logs" / 'generative_model.log'
logger_config.add_file_handler(logs_path)
logger.info(f'Running generative model with mode {mode}...')
model_name = cfg['model_name']
n_cluster = cfg['n_cluster']
parametrization = cfg['parametrization']

files = []
if cfg['all_data'] == 'No':
all_flag = input("Do you want to write motif videos for your entire dataset? \n"
"If you only want to use a specific dataset type filename: \n"
"yes/no/filename ")
else:
all_flag = 'yes'

if all_flag == 'yes' or all_flag == 'Yes':
for file in cfg['video_sets']:
files.append(file)
if use_file == 'no':
continue
else:
files.append(all_flag)


model = load_model(cfg, model_name)

for file in files:
path_to_file=os.path.join(cfg['project_path'],"results",file,model_name, parametrization + '-' +str(n_cluster),"")

if mode == "sampling":
latent_vector = np.load(os.path.join(path_to_file,'latent_vector_'+file+'.npy'))
random_generative_samples(cfg, model, latent_vector)

if mode == "reconstruction":
latent_vector = np.load(os.path.join(path_to_file,'latent_vector_'+file+'.npy'))
random_reconstruction_samples(cfg, model, latent_vector)

if mode == "centers":
cluster_center = np.load(os.path.join(path_to_file,'cluster_center_'+file+'.npy'))
visualize_cluster_center(cfg, model, cluster_center)

if mode == "motifs":
latent_vector = np.load(os.path.join(path_to_file,'latent_vector_'+file+'.npy'))
labels = np.load(os.path.join(path_to_file,"",str(n_cluster)+'_' + parametrization + '_label_'+file+'.npy'))
random_generative_samples_motif(cfg, model, latent_vector,labels,n_cluster)











elif all_flag == 'no' or all_flag == 'No':
for file in cfg['video_sets']:
use_file = input("Do you want to quantify " + file + "? yes/no: ")
if use_file == 'yes':
files.append(file)
if use_file == 'no':
continue
else:
files.append(all_flag)


model = load_model(cfg, model_name)

for file in files:
path_to_file=os.path.join(cfg['project_path'],"results",file,model_name, parametrization + '-' +str(n_cluster),"")

if mode == "sampling":
latent_vector = np.load(os.path.join(path_to_file,'latent_vector_'+file+'.npy'))
return random_generative_samples(cfg, model, latent_vector)

if mode == "reconstruction":
latent_vector = np.load(os.path.join(path_to_file,'latent_vector_'+file+'.npy'))
return random_reconstruction_samples(cfg, model, latent_vector)

if mode == "centers":
cluster_center = np.load(os.path.join(path_to_file,'cluster_center_'+file+'.npy'))
return visualize_cluster_center(cfg, model, cluster_center)

if mode == "motifs":
latent_vector = np.load(os.path.join(path_to_file,'latent_vector_'+file+'.npy'))
labels = np.load(os.path.join(path_to_file,"",str(n_cluster)+'_' + parametrization + '_label_'+file+'.npy'))
return random_generative_samples_motif(cfg, model, latent_vector,labels,n_cluster)
except Exception as e:
logger.exception(str(e))
raise
finally:
logger_config.remove_file_handler()
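Editor's note: after this commit, generative_model is decorated with save_state, takes an optional save_logs flag, and returns the matplotlib figure for the selected mode ("sampling", "reconstruction", "centers", or "motifs") instead of None. A minimal usage sketch against the new signature; the import path follows the file shown in this diff, while the config path and output filename are placeholders:

import matplotlib.pyplot as plt

from vame.analysis.generative_functions import generative_model

config = "/path/to/my_project/config.yaml"  # hypothetical project config path

# Returns the matplotlib figure for the chosen mode; per-run file logging is optional.
fig = generative_model(config, mode="centers", save_logs=True)
if fig is not None:
    fig.savefig("cluster_centers.png")
    plt.show()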