chore: merge dev into release

versey-sherry authored Apr 5, 2023
2 parents d4a63eb + 7a65fb3 commit 93052b4
Showing 13 changed files with 208 additions and 47 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -14,6 +14,7 @@ jobs:
stage: latest-pythons
before_install:
- pip install -U pip
+ - pip install numpy==1.18.3
- pip install pytest==5.4.1 codecov pytest-cov
- export PYTHONPATH=$PYTHONPATH:$(pwd)
install:
2 changes: 1 addition & 1 deletion moseq2_model/__init__.py
@@ -1 +1 @@
- __version__ = 'v1.1.2'
+ __version__ = 'v1.2.0'
20 changes: 17 additions & 3 deletions moseq2_model/cli.py
@@ -6,7 +6,7 @@
import click
from os.path import join
from moseq2_model.util import count_frames as count_frames_wrapper
- from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper
+ from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper, apply_model_wrapper

orig_init = click.core.Option.__init__

@@ -46,7 +46,7 @@ def modeling_parameters(function):
function = click.option('--e-step', is_flag=True, help="Compute the expected state sequence for each recording")(function)
function = click.option("--save-every", "-s", type=int, default=-1,
help="Increment to save labels and model object (-1 for just last)")(function)
function = click.option("--save-model", is_flag=True, help="Save model object at the end of training")(function)
function = click.option("--save-model", type=bool, default=True, help="Save model object at the end of training")(function)
function = click.option("--max-states", "-m", type=int, default=100, help="Maximum number of states")(function)
function = click.option("--npcs", type=int, default=10, help="Number of PCs to use")(function)
function = click.option("--whiten", "-w", type=str, default='all', help="Whiten PCs: (e)each session (a)ll combined or (n)o whitening")(function)
@@ -73,14 +73,28 @@ def modeling_parameters(function):
@click.option("--kappa", "-k", type=float, default=None, help="Kappa; hyperparameter used to set syllable duration. Larger k = longer syllable lengths")
@click.option("--checkpoint-freq", type=int, default=-1, help='save model checkpoint every n iterations')
@click.option("--use-checkpoint", is_flag=True, help='indicate whether to use previously saved checkpoint')
@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions (used only with the separate-trans flag)")
@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions")
@click.option("--default-group", type=str, default="n/a", help="Default group name to use for separate-trans")
@click.option("--verbose", '-v', is_flag=True, help="Print syllable log-likelihoods during training.")
def learn_model(input_file, dest_file, **config_data):
# Train the ARHMM using PC scores located in the INPUT_FILE, and save the model to DEST_FILE

learn_model_wrapper(input_file, dest_file, config_data)


+ @cli.command(name='apply-model', help='Apply pre-trained ARHMM to PC scores.')
+ @click.argument("model_file", type=click.Path(exists=True))
+ @click.argument("pc_file", type=click.Path(exists=True))
+ @click.argument("dest_file", type=click.Path(file_okay=True, writable=True, resolve_path=True))
+ @click.option("--var-name", type=str, default='scores', help="Variable name in input file with PCs")
+ @click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions")
+ @click.option("--load-groups", type=bool, default=True, help="If groups should be loaded with the PC scores.")
+ def apply_model(model_file, pc_file, dest_file, **config_data):
+     # Apply the ARHMM located in MODEL_FILE to the PC scores in PC_FILE, and save the results to DEST_FILE
+
+     apply_model_wrapper(model_file, pc_file, dest_file, config_data)


@cli.command(name='kappa-scan', help='Batch train multiple models to scan over different kappa values.')
@click.argument('input_file', type=click.Path(exists=True))
@click.argument('output_dir', type=click.Path(exists=False))
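Note: the new apply-model command can be exercised end to end with click's test runner. A minimal sketch, assuming the package is installed; the file paths are placeholders, and cli is the click group defined in moseq2_model/cli.py:

    from click.testing import CliRunner
    from moseq2_model.cli import cli

    runner = CliRunner()
    result = runner.invoke(cli, [
        'apply-model',
        'results/model.p',           # MODEL_FILE: pre-trained ARHMM (hypothetical path)
        'data/pca_scores.h5',        # PC_FILE: PC scores to label (hypothetical path)
        'results/applied_model.p',   # DEST_FILE: output file (hypothetical path)
        '--var-name', 'scores',
    ])
    print(result.exit_code, result.output)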
27 changes: 25 additions & 2 deletions moseq2_model/gui.py
@@ -5,7 +5,7 @@
import ruamel.yaml as yaml
from moseq2_model.cli import learn_model, kappa_scan_fit_models
from os.path import dirname, join, exists
- from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper
+ from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper, apply_model_wrapper

def learn_model_command(progress_paths, get_cmd=True, verbose=False):
"""
@@ -60,4 +60,27 @@ def learn_model_command(progress_paths, get_cmd=True, verbose=False):
command = kappa_scan_fit_models_wrapper(input_file, config_data, output_dir)
return command
else:
- learn_model_wrapper(input_file, dest_file, config_data)
\ No newline at end of file
+ learn_model_wrapper(input_file, dest_file, config_data)


+ def apply_model_command(progress_paths, model_file):
+     """Apply a pre-trained ARHMM to a new dataset from within a Jupyter notebook.
+
+     Args:
+         progress_paths (dict): notebook progress dict that contains paths to the pc scores, config, and index files.
+         model_file (str): path to the pre-trained ARHMM.
+     """
+
+     # Load proper input variables
+     pc_file = progress_paths['scores_path']
+     dest_file = progress_paths['model_path']
+     config_file = progress_paths['config_file']
+     index = progress_paths['index_file']
+     output_dir = progress_paths['base_model_path']
+
+     # Load config data
+     with open(config_file, 'r') as f:
+         config_data = yaml.safe_load(f)
+
+     # Apply model to data
+     apply_model_wrapper(model_file, pc_file, dest_file, config_data)
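Note: a minimal sketch of calling the new notebook helper; the progress_paths keys match the ones read above, and every path below is a placeholder:

    from moseq2_model.gui import apply_model_command

    progress_paths = {
        'scores_path': 'data/pca_scores.h5',       # PC scores for the new dataset
        'model_path': 'results/applied_model.p',   # where the applied-model output is written
        'config_file': 'config.yaml',              # modeling configuration
        'index_file': 'moseq2-index.yaml',         # group definitions
        'base_model_path': 'results/',             # model output directory
    }

    apply_model_command(progress_paths, model_file='results/model.p')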
9 changes: 6 additions & 3 deletions moseq2_model/helpers/data.py
@@ -174,12 +174,15 @@ def prepare_model_metadata(data_dict, data_metadata, config_data):
model_parameters['groups'] = {k: data_metadata['groups'][k] for k in train}

    # Whiten the data
+   whitening_parameters = None
    if config_data['whiten'][0].lower() == 'a':
        click.echo('Whitening the training data using the whiten_all function')
-       data_dict = whiten_all(data_dict)
+       # in this case, whitening_parameters is a single dict of parameters (mu, L, offset)
+       data_dict, whitening_parameters = whiten_all(data_dict)
    elif config_data['whiten'][0].lower() == 'e':
        click.echo('Whitening the training data using the whiten_each function')
-       data_dict = whiten_each(data_dict)
+       # in this case, whitening_parameters is a dict of per-session parameters
+       data_dict, whitening_parameters = whiten_each(data_dict)
    else:
        click.echo('Not whitening the data')

@@ -189,7 +192,7 @@ def prepare_model_metadata(data_dict, data_metadata, config_data):
for k, v in data_dict.items():
data_dict[k] = v + np.random.randn(*v.shape) * config_data['noise_level']

- return data_dict, model_parameters, train, hold_out
+ return data_dict, model_parameters, train, hold_out, whitening_parameters


def get_heldout_data_splits(data_dict, train_list, hold_out_list):
77 changes: 70 additions & 7 deletions moseq2_model/helpers/wrappers.py
@@ -6,12 +6,13 @@
import sys
import glob
import click
+ import numpy as np
from copy import deepcopy
from collections import OrderedDict
- from moseq2_model.train.util import train_model, run_e_step
- from os.path import join, basename, realpath, dirname, exists, splitext
+ from cytoolz import valmap
+ from moseq2_model.train.util import train_model, run_e_step, apply_model
+ from os.path import join, basename, realpath, dirname, splitext
from moseq2_model.util import (save_dict, load_pcs, get_parameters_from_model, copy_model, get_scan_range_kappas,
- create_command_strings, get_current_model, get_loglikelihoods, get_session_groupings)
+ create_command_strings, get_current_model, get_loglikelihoods, get_session_groupings, load_dict)
from moseq2_model.helpers.data import (process_indexfile, select_data_to_model, prepare_model_metadata,
graph_modeling_loglikelihoods, get_heldout_data_splits, get_training_data_splits)

@@ -56,20 +57,22 @@ def learn_model_wrapper(input_file, dest_file, config_data):
load_groups=config_data['load_groups'])

# Parse index file and update metadata information; namely groups
+ # If no group data in pca data, use group info from index file
select_groups = config_data.get('select_groups', False)
index_data, data_metadata = process_indexfile(config_data.get('index', None), data_metadata,
config_data['default_group'], select_groups)

# Get keys to include in training set
+ # TODO: select_groups not implemented
if index_data is not None:
data_dict, data_metadata = select_data_to_model(index_data, data_dict,
data_metadata, select_groups)

all_keys = list(data_dict)
groups = list(data_metadata['groups'].values())

# Get train/held out data split uuids
- data_dict, model_parameters, train_list, hold_out_list = \
+ data_dict, model_parameters, train_list, hold_out_list, whitening_parameters = \
prepare_model_metadata(data_dict, data_metadata, config_data)

# Pack data dicts corresponding to uuids in train_list and hold_out_list
@@ -154,7 +157,9 @@ def learn_model_wrapper(input_file, dest_file, config_data):
'hold_out_list': hold_out_list,
'train_list': train_list,
'train_ll': train_ll,
- 'expected_states': expected_states if config_data['e_step'] else None
+ 'expected_states': expected_states if config_data['e_step'] else None,
+ 'whitening_parameters': whitening_parameters,
+ 'pc_score_path': os.path.abspath(input_file)
}

# Save model
@@ -168,6 +173,64 @@ def learn_model_wrapper(input_file, dest_file, config_data):
return img_path


+ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
+     """
+     Wrapper function to apply a pre-trained model to new data.
+
+     Args:
+         model_file (str): Path to pre-trained model file
+         pc_file (str): Path to PC scores file
+         dest_file (str): Path to save output file
+         config_data (dict): Dictionary of configuration parameters (e.g. var_name, index, load_groups)
+
+     Returns:
+         None
+     """
+
+     assert splitext(basename(dest_file))[-1] in ['.mat', '.z', '.pkl', '.p', '.h5'], 'Incorrect model filetype'
+     os.makedirs(dirname(dest_file), exist_ok=True)
+
+     if not os.access(dirname(dest_file), os.W_OK):
+         raise IOError('Output directory is not writable.')
+
+     # Load model
+     model_data = load_dict(model_file)
+
+     if model_data.get('whitening_parameters') is None:
+         raise KeyError('Whitening parameters not found in model file. Unable to apply model to new data. '
+                        'Please retrain the model using the latest version.')
+
+     # Load PC scores
+     data_dict, data_metadata = load_pcs(filename=pc_file, var_name=config_data.get('var_name', 'scores'),
+                                         npcs=model_data['run_parameters']['npcs'],
+                                         load_groups=config_data.get('load_groups', False))
+
+     # Parse group information from the index file
+     index_data, data_metadata = process_indexfile(config_data.get('index', None), data_metadata,
+                                                   config_data.get('default_group', 'n/a'), select_groups=False)
+
+     # Apply model
+     syllables = apply_model(model_data['model'], model_data['whitening_parameters'], data_dict, data_metadata,
+                             model_data['run_parameters']['whiten'])
+
+     # Add -5 padding to the beginning of each state sequence to account for the AR lags
+     nlags = model_data['run_parameters'].get('nlags', 3)
+     syllables = valmap(lambda v: np.concatenate(([-5] * nlags, v)), syllables)
+
+     # Prepare the applied model data dictionary to save
+     applied_model_data = {}
+     applied_model_data['labels'] = list(syllables.values())
+     applied_model_data['keys'] = list(syllables.keys())
+     applied_model_data['metadata'] = data_metadata
+     applied_model_data['pc_score_path'] = os.path.abspath(pc_file)
+     applied_model_data['pre_trained_model_path'] = os.path.abspath(model_file)
+
+     # Copy over pre-trained model data
+     for key in ['model_parameters', 'run_parameters', 'model', 'whitening_parameters']:
+         applied_model_data[key] = model_data[key]
+
+     # Save output
+     save_dict(filename=dest_file, obj_to_save=applied_model_data)


def kappa_scan_fit_models_wrapper(input_file, config_data, output_dir):
"""
Wrapper function to output multiple model training commands for a range of kappa values.
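Note: the -5 padding above keeps the returned labels aligned with the original frames; the first nlags frames of each session have no autoregressive prediction, so they are marked -5. A small illustration with a made-up state sequence:

    import numpy as np
    from cytoolz import valmap

    nlags = 3  # default number of AR lags
    syllables = {'session-1': np.array([4, 4, 7, 7])}  # toy decoded states
    padded = valmap(lambda v: np.concatenate(([-5] * nlags, v)), syllables)
    print(padded['session-1'])  # [-5 -5 -5  4  4  7  7]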
4 changes: 2 additions & 2 deletions moseq2_model/train/models.py
@@ -1,5 +1,5 @@
"""
- ARHMM model initialization utilities.
+ ARHMM initialization utilities.
"""

import warnings
@@ -54,7 +54,7 @@ def ARHMM(data_dict, kappa=1e6, gamma=999, nlags=3, alpha=5.7,
affine=True, model_hypparams={}, obs_hypparams={}, sticky_init=False,
separate_trans=False, groups=None, robust=False, silent=False):
"""
- Initialize ARHMM and add data and group labels to the ARHMM model.
+ Initialize ARHMM and add data and group labels to the ARHMM.
Args:
data_dict (OrderedDict): training data to add to model
56 changes: 46 additions & 10 deletions moseq2_model/train/util.py
@@ -1,13 +1,10 @@
"""
ARHMM utility functions
"""

- import math
import numpy as np
- from cytoolz import valmap
from tqdm.auto import tqdm
- from scipy.stats import norm
- from functools import partial
+ from cytoolz import valmap, itemmap
from collections import OrderedDict, defaultdict
from moseq2_model.util import save_arhmm_checkpoint, get_loglikelihoods

@@ -147,7 +144,7 @@ def get_labels_from_model(model):
Grab model labels for each training dataset and place them in a list.
Args:
- model (ARHMM): trained ARHMM model
+ model (ARHMM): trained ARHMM
Returns:
labels (list): An array of predicted syllable labels for each training session
@@ -157,6 +154,44 @@ def get_labels_from_model(model):
return labels


+ def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
+     """
+     Apply a pre-trained model to data_dict. Note that this function might produce unexpected behavior
+     if the model was trained using separate transition matrices for different groups of sessions.
+
+     Args:
+         model (ARHMM): pre-trained model
+         whitening_params (dict): whitening parameters saved during training (one dict per session when whiten='each')
+         data_dict (OrderedDict): data to apply model to
+         metadata (dict): metadata for data_dict
+         whiten (str): whitening mode used during training ('all' or 'each')
+
+     Returns:
+         labels (dict): dictionary of labels predicted per session after modeling
+     """
+
+     # check the whitening mode to see if whiten_all or whiten_each was used during training
+     if whiten[0].lower() == 'e':
+         # this approach is not recommended, but supported;
+         # whitening_params holds one parameter dict per session
+         center = whitening_params[list(whitening_params)[0]]['offset'] == 0
+         whitened_data, _ = whiten_each(data_dict, center)
+     else:
+         # rebuild the whitening transform saved during training
+         mu, L, offset = whitening_params['mu'], whitening_params['L'], whitening_params['offset']
+         apply_whitening = lambda x: np.linalg.solve(L, (x - mu).T).T + offset
+         whitened_data = valmap(apply_whitening, data_dict)
+
+     # apply model to data
+     if 'SeparateTrans' in str(type(model)):
+         # not recommended, but supported
+         labels = itemmap(lambda item: (item[0], model.heldout_viterbi(item[1], group_id=metadata['groups'][item[0]])), whitened_data)
+     else:
+         labels = valmap(model.heldout_viterbi, whitened_data)
+
+     return labels



# taken from moseq by @mattjj and @alexbw
def whiten_all(data_dict, center=True):
"""
@@ -178,9 +213,10 @@ def whiten_all(data_dict, center=True):
L = np.linalg.cholesky(Sigma)

offset = 0. if center else mu
+ # set up function to whiten data
apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset

- return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items())
+ whitening_parameters = {'mu': mu, 'L': L, 'offset': offset}
+ return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items()), whitening_parameters


# taken from moseq by @mattjj and @alexbw
@@ -195,12 +231,12 @@ def whiten_each(data_dict, center=True):
Returns:
data_dict (OrderedDict): Whitened training data dictionary
+ whitening_parameters (dict): whitening parameters computed for each session
"""

+ whitening_parameters = {}
for k, v in data_dict.items():
- tmp_dict = whiten_all({k: v}, center=center)
+ tmp_dict, whitening_parameters[k] = whiten_all({k: v}, center=center)
data_dict[k] = tmp_dict[k]

- return data_dict
+ return data_dict, whitening_parameters


def run_e_step(arhmm):
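Note: the whitening changes above are what make apply-model possible: whiten_all now returns its (mu, L, offset) parameters alongside the whitened data, and apply_model rebuilds the same transform from them. A round-trip sketch with toy arrays (shapes and session names are arbitrary):

    import numpy as np
    from collections import OrderedDict
    from moseq2_model.train.util import whiten_all

    train = OrderedDict(sess_a=np.random.randn(100, 10), sess_b=np.random.randn(80, 10))
    whitened, params = whiten_all(train)  # params == {'mu': ..., 'L': ..., 'offset': ...}

    # at apply time, the saved parameters reproduce the training-time transform
    mu, L, offset = params['mu'], params['L'], params['offset']
    new_session = np.random.randn(50, 10)
    whitened_new = np.linalg.solve(L, (new_session - mu).T).T + offset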