From e5178c827839bc44eff1f051ca22541cf41ee203 Mon Sep 17 00:00:00 2001
From: wingillis
Date: Thu, 23 Mar 2023 18:47:51 -0400
Subject: [PATCH 01/29] Add ability to apply pre-trained model

- add CLI option called `apply-model` that accepts a path to a pre-trained model and a path to PC scores
- added function to load a dictionary from a saved file
- added ability to save whitening parameters used during modeling
- saves a new file with predicted syllables
---
 moseq2_model/cli.py              | 17 +++++++-
 moseq2_model/helpers/data.py     |  9 +++--
 moseq2_model/helpers/wrappers.py | 38 +++++++++++++++---
 moseq2_model/train/util.py       | 69 +++++++++++++++++++++++++++-----
 moseq2_model/util.py             | 23 ++++++++++-
 5 files changed, 133 insertions(+), 23 deletions(-)

diff --git a/moseq2_model/cli.py b/moseq2_model/cli.py
index 864ff5d..a861bbb 100644
--- a/moseq2_model/cli.py
+++ b/moseq2_model/cli.py
@@ -6,7 +6,7 @@ import click
 from os.path import join
 
 from moseq2_model.util import count_frames as count_frames_wrapper
-from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper
+from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper, apply_model_wrapper
 
 orig_init = click.core.Option.__init__
 
@@ -46,7 +46,7 @@ def modeling_parameters(function):
     function = click.option('--e-step', is_flag=True, help="Compute the expected state sequence for each recordings")(function)
     function = click.option("--save-every", "-s", type=int, default=-1, help="Increment to save labels and model object (-1 for just last)")(function)
-    function = click.option("--save-model", is_flag=True, help="Save model object at the end of training")(function)
+    function = click.option("--save-model", type=bool, default=True, help="Save model object at the end of training")(function)
     function = click.option("--max-states", "-m", type=int, default=100, help="Maximum number of states")(function)
     function = click.option("--npcs", type=int, default=10, help="Number of PCs to use")(function)
     function = click.option("--whiten", "-w", type=str, default='all', help="Whiten PCs: (e)each session (a)ll combined or (n)o whitening")(function)
@@ -81,6 +81,19 @@ def learn_model(input_file, dest_file, **config_data):
 
     learn_model_wrapper(input_file, dest_file, config_data)
 
+
+@cli.command(name='apply-model', help='Apply trained ARHMM model to PC scores.')
+@click.argument("model_file", type=click.Path(exists=True))
+@click.argument("pc_file", type=click.Path(exists=True))
+@click.argument("dest_file", type=click.Path(file_okay=True, writable=True, resolve_path=True))
+@click.option("--var-name", type=str, default='scores', help="Variable name in input file with PCs")
+@click.option("--load-groups", type=bool, default=True, help="If groups should be loaded with the PC scores.")
+def apply_model(model_file, pc_file, dest_file, **config_data):
+    # Apply the ARHMM model located in MODEL_FILE to the PC scores in PC_FILE, and saves the results to DEST_FILE
+
+    apply_model_wrapper(model_file, pc_file, dest_file, config_data)
+
+
 @cli.command(name='kappa-scan', help='Batch train multiple model to scan over different kappa values.')
 @click.argument('input_file', type=click.Path(exists=True))
 @click.argument('output_dir', type=click.Path(exists=False))
diff --git a/moseq2_model/helpers/data.py b/moseq2_model/helpers/data.py
index 06aba2e..7a47540 100644
--- a/moseq2_model/helpers/data.py
+++ b/moseq2_model/helpers/data.py
@@ -174,12 +174,15 @@ def prepare_model_metadata(data_dict, data_metadata, config_data):
     model_parameters['groups'] = {k: data_metadata['groups'][k] for k in train}
 
     # Whiten the data
+    whitening_parameters = None
     if config_data['whiten'][0].lower() == 'a':
         click.echo('Whitening the training data using the whiten_all function')
-        data_dict = whiten_all(data_dict)
+        # in this case, whitening_parameters is a single tuple
+        data_dict, whitening_parameters = whiten_all(data_dict)
     elif config_data['whiten'][0].lower() == 'e':
         click.echo('Whitening the training data using the whiten_each function')
-        data_dict = whiten_each(data_dict)
+        # in this case, whitening_parameters is a dictionary of parameters
+        data_dict, whitening_parameters = whiten_each(data_dict)
     else:
         click.echo('Not whitening the data')
@@ -189,7 +192,7 @@ def prepare_model_metadata(data_dict, data_metadata, config_data):
         for k, v in data_dict.items():
             data_dict[k] = v + np.random.randn(*v.shape) * config_data['noise_level']
 
-    return data_dict, model_parameters, train, hold_out
+    return data_dict, model_parameters, train, hold_out, whitening_parameters
 
 
 def get_heldout_data_splits(data_dict, train_list, hold_out_list):
diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index ad5745f..d9d8579 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -7,11 +7,10 @@
 import glob
 import click
 
 from copy import deepcopy
-from collections import OrderedDict
-from moseq2_model.train.util import train_model, run_e_step
-from os.path import join, basename, realpath, dirname, exists, splitext
+from moseq2_model.train.util import train_model, run_e_step, apply_model
+from os.path import join, basename, realpath, dirname, splitext
 from moseq2_model.util import (save_dict, load_pcs, get_parameters_from_model, copy_model, get_scan_range_kappas,
-                               create_command_strings, get_current_model, get_loglikelihoods, get_session_groupings)
+                               create_command_strings, get_current_model, get_loglikelihoods, get_session_groupings, load_dict)
 from moseq2_model.helpers.data import (process_indexfile, select_data_to_model, prepare_model_metadata,
                                        graph_modeling_loglikelihoods, get_heldout_data_splits, get_training_data_splits)
@@ -69,7 +68,7 @@ def learn_model_wrapper(input_file, dest_file, config_data):
     groups = list(data_metadata['groups'].values())
 
     # Get train/held out data split uuids
-    data_dict, model_parameters, train_list, hold_out_list = \
+    data_dict, model_parameters, train_list, hold_out_list, whitening_parameters = \
         prepare_model_metadata(data_dict, data_metadata, config_data)
 
     # Pack data dicts corresponding to uuids in train_list and hold_out_list
@@ -154,7 +153,8 @@ def learn_model_wrapper(input_file, dest_file, config_data):
         'hold_out_list': hold_out_list,
         'train_list': train_list,
         'train_ll': train_ll,
-        'expected_states': expected_states if config_data['e_step'] else None
+        'expected_states': expected_states if config_data['e_step'] else None,
+        'whitening_parameters': whitening_parameters,
     }
 
     # Save model
@@ -168,6 +168,32 @@ def learn_model_wrapper(input_file, dest_file, config_data):
 
     return img_path
 
 
+def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
+    """
+    Wrapper function to apply a trained model to new data.
+
+    Args:
+        model_file (str): Path to trained model file
+        pc_file (str): Path to PC scores file
+        dest_file (str): Path to save output file
+
+    Returns:
+        None
+    """
+    # Load model
+    model_data = load_dict(model_file)
+
+    # Load PC scores
+    data_dict, data_metadata = load_pcs(filename=pc_file, var_name=config_data.get('var_name', 'scores'), npcs=model_data['run_parameters']['npcs'],
+                                        load_groups=config_data.get('load_groups', False))
+
+    # Apply model
+    syllables = apply_model(model_data['model'], model_data['whitening_parameters'], data_dict, data_metadata)
+
+    # Save output
+    save_dict(filename=dest_file, obj_to_save=syllables)
+
+
 def kappa_scan_fit_models_wrapper(input_file, config_data, output_dir):
     """
     Wrapper function to output multiple model training commands for a range of kappa values.
diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index e8ce2ca..0e716cf 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -1,14 +1,12 @@
 """
 ARHMM utility functions
 """
-
-import math
 import numpy as np
-from cytoolz import valmap
 from tqdm.auto import tqdm
-from scipy.stats import norm
 from functools import partial
-from collections import OrderedDict, defaultdict
+from cytoolz import valmap, itemmap
+from collections import OrderedDict, defaultdict, namedtuple
 from moseq2_model.util import save_arhmm_checkpoint, get_loglikelihoods
 
 
@@ -157,6 +155,56 @@ def get_labels_from_model(model):
     return labels
 
 
+WhiteningParams = namedtuple('WhiteningParams', ['mu', 'L', 'offset'])
+
+
+def apply_model(model, whitening_params: Union[WhiteningParams, dict], data_dict, metadata):
+    '''
+    Apply trained model to data_dict. Note that this function might produce unexpected behavior
+    if the model was trained using separate transition matrices for different groups of sessions.
+
+    Args:
+        model (ARHMM): trained model
+        whitening_params (namedtuple or dict): whitening parameters
+        data_dict (OrderedDict): data to apply model to
+        metadata (dict): metadata for data_dict
+
+    Returns:
+        labels (dict): dictionary of labels predicted per session after modeling
+    '''
+    if isinstance(whitening_params, dict):
+        # this approach is not recommended, but supported
+        center = whitening_params[list(whitening_params)[0]].offset == 0
+        whitened_data, _ = whiten_each(data_dict, center)
+    else:
+        whitened_data = valmap(lambda x: apply_whitening(x, whitening_params.mu, whitening_params.L, whitening_params.offset), data_dict)
+
+    # apply model to data
+    if 'SeparateTrans' in type(model):
+        # not recommended, but supported
+        labels = itemmap(lambda item: (item[0], model.heldout_viterbi(item[1], group_id=metadata['groups'][item[0]])), whitened_data)
+    else:
+        labels = valmap(model.heldout_viterbi, whitened_data)
+
+    return labels
+
+
+def apply_whitening(data, mu, L, offset=0):
+    '''Apply whitening to data.
+
+    Args:
+        data (np.ndarray): data to be whitened
+        mu (np.ndarray): mean of data
+        L (np.ndarray): Cholesky decomposition of covariance matrix
+        offset (float): offset to add to whitened data
+
+    Returns:
+        data (np.ndarray): whitened data
+    '''
+
+    return np.linalg.solve(L, (data - mu).T).T + offset
+
+
 # taken from moseq by @mattjj and @alexbw
 def whiten_all(data_dict, center=True):
     """
@@ -178,9 +226,8 @@ def whiten_all(data_dict, center=True):
     L = np.linalg.cholesky(Sigma)
     offset = 0. if center else mu
-    apply_whitening = lambda x:  np.linalg.solve(L, (x-mu).T).T + offset
-
-    return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items())
+    apply_whitening = partial(apply_whitening, mu=mu, L=L, offset=offset)
+    return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items()), WhiteningParams(mu, L, offset)
 
 
 # taken from moseq by @mattjj and @alexbw
@@ -195,12 +242,12 @@ def whiten_each(data_dict, center=True):
     Returns:
         data_dict (OrderedDict): Whitened training data dictionary
     """
-
+    whitening_parameters = {}
     for k, v in data_dict.items():
-        tmp_dict = whiten_all({k: v}, center=center)
+        tmp_dict, whitening_parameters[k] = whiten_all({k: v}, center=center)
         data_dict[k] = tmp_dict[k]
 
-    return data_dict
+    return data_dict, whitening_parameters
 
 
 def run_e_step(arhmm):
diff --git a/moseq2_model/util.py b/moseq2_model/util.py
index 4716d5e..5c9737b 100644
--- a/moseq2_model/util.py
+++ b/moseq2_model/util.py
@@ -1,11 +1,11 @@
 """
 Utility functions for handling loading and saving models and their respective metadata.
 """
+import os
 import re
 import h5py
 import click
 import joblib
-import pickle
 import scipy.io
 import warnings
 import numpy as np
@@ -87,6 +87,9 @@ def load_pcs(filename, var_name="features", load_groups=False, npcs=10):
             else:
                 warnings.warn('groups key not found in h5 file, assigning each session to unique group...')
                 metadata['groups'] = {key: i for i, key in enumerate(data_dict)}
+        else:
+            warnings.warn('groups not loaded from h5 file. Assigning each session to unique group')
+            metadata['groups'] = {key: i for i, key in enumerate(data_dict)}
         else:
             raise IOError('Could not load data from h5 file')
     else:
@@ -232,6 +235,24 @@ def save_dict(filename, obj_to_save=None):
     else:
         raise ValueError('Did not understand filetype')
 
+
+def load_dict(filename):
+    """
+    Load dictionary from file.
+
+    Args:
+        filename (str): path to file where dict is saved
+
+    Returns:
+        obj (dict): loaded dictionary
+    """
+    if filename.endswith(".h5"):
+        return h5_to_dict(filename)
+    elif filename.endswith(("pkl", "p", "z")):
+        return joblib.load(filename)
+    else:
+        raise ValueError(f"Does not support filetype {os.path.splitext(filename)[1]}")
+
 
 def dict_to_h5(h5file, export_dict, path='/'):
     """

From 06fa80cfe5b0cff7ed28b13bad966385ed7a4413 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Tue, 28 Mar 2023 20:12:20 -0400
Subject: [PATCH 02/29] chore: add docs for parameters

---
 moseq2_model/helpers/wrappers.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index ad5745f..72fa996 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -61,14 +61,16 @@ def learn_model_wrapper(input_file, dest_file, config_data):
                                                   config_data['default_group'], select_groups)
 
     # Get keys to include in training set
+    # If no group data in pca data, use group info from index file
     if index_data is not None:
         data_dict, data_metadata = select_data_to_model(index_data, data_dict,
                                                         data_metadata, select_groups)
-
+    
     all_keys = list(data_dict)
     groups = list(data_metadata['groups'].values())
 
     # Get train/held out data split uuids
+    # Whiten data and get model parameters
     data_dict, model_parameters, train_list, hold_out_list = \
         prepare_model_metadata(data_dict, data_metadata, config_data)
 

From b959ceff2786674e934bef5aaf8ee156f0abcf3f Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 29 Mar 2023 14:33:52 -0400
Subject: [PATCH 03/29] chore: add more info to warning message

---
 moseq2_model/util.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/moseq2_model/util.py b/moseq2_model/util.py
index 5c9737b..e8fb2a2 100644
--- a/moseq2_model/util.py
+++ b/moseq2_model/util.py
@@ -85,10 +85,10 @@ def load_pcs(filename, var_name="features", load_groups=False, npcs=10):
             if 'groups' in f:
                 metadata['groups'] = {key: f['groups'][i] for i, key in enumerate(data_dict) if key in f['metadata']}
             else:
-                warnings.warn('groups key not found in h5 file, assigning each session to unique group...')
+                warnings.warn('groups key not found in h5 file, assigning each session to unique group if no moseq2-index.yaml')
                 metadata['groups'] = {key: i for i, key in enumerate(data_dict)}
         else:
-            warnings.warn('groups not loaded from h5 file. Assigning each session to unique group')
+            warnings.warn('groups not loaded from h5 file. Assigning each session to unique group if no moseq2-index.yaml')
             metadata['groups'] = {key: i for i, key in enumerate(data_dict)}

From 6cd0b24b9a88496306b2144b7c0b352e44908e78 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 29 Mar 2023 19:33:00 -0400
Subject: [PATCH 04/29] chore: add docs and update index file behavior

---
 moseq2_model/cli.py              | 3 ++-
 moseq2_model/helpers/wrappers.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/moseq2_model/cli.py b/moseq2_model/cli.py
index a861bbb..e4162d8 100644
--- a/moseq2_model/cli.py
+++ b/moseq2_model/cli.py
@@ -73,7 +73,7 @@ def modeling_parameters(function):
 @click.option("--kappa", "-k", type=float, default=None, help="Kappa; hyperparameter used to set syllable duration. Larger k = longer syllable lengths")
 @click.option("--checkpoint-freq", type=int, default=-1, help='save model checkpoint every n iterations')
 @click.option("--use-checkpoint", is_flag=True, help='indicate whether to use previously saved checkpoint')
-@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions (used only with the separate-trans flag)")
+@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions")
 @click.option("--default-group", type=str, default="n/a", help="Default group name to use for separate-trans")
 @click.option("--verbose", '-v', is_flag=True, help="Print syllable log-likelihoods during training.")
 def learn_model(input_file, dest_file, **config_data):
@@ -87,6 +87,7 @@ def learn_model(input_file, dest_file, **config_data):
 @click.argument("pc_file", type=click.Path(exists=True))
 @click.argument("dest_file", type=click.Path(file_okay=True, writable=True, resolve_path=True))
 @click.option("--var-name", type=str, default='scores', help="Variable name in input file with PCs")
+@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions")
 @click.option("--load-groups", type=bool, default=True, help="If groups should be loaded with the PC scores.")
 def apply_model(model_file, pc_file, dest_file, **config_data):
     # Apply the ARHMM model located in MODEL_FILE to the PC scores in PC_FILE, and saves the results to DEST_FILE
diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index 00b2b62..7449ee1 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -55,12 +55,13 @@ def learn_model_wrapper(input_file, dest_file, config_data):
                        load_groups=config_data['load_groups'])
 
     # Parse index file and update metadata information; namely groups
+    # If no group data in pca data, use group info from index file
     select_groups = config_data.get('select_groups', False)
     index_data, data_metadata = process_indexfile(config_data.get('index', None), data_metadata,
                                                   config_data['default_group'], select_groups)
 
     # Get keys to include in training set
-    # If no group data in pca data, use group info from index file
+    # TODO: select_groups not implemented
     if index_data is not None:
         data_dict, data_metadata = select_data_to_model(index_data, data_dict,
                                                         data_metadata, select_groups)

From 671d290903178f2c969d0aca7ea15c899467bf47 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 29 Mar 2023 19:45:04 -0400
Subject: [PATCH 05/29] chore: add group from index file

---
 moseq2_model/helpers/wrappers.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index 7449ee1..1b94d21 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -189,6 +189,10 @@ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
     data_dict, data_metadata = load_pcs(filename=pc_file, var_name=config_data.get('var_name', 'scores'), npcs=model_data['run_parameters']['npcs'],
                                         load_groups=config_data.get('load_groups', False))
 
+    # parse group information from index file
+    index_data, data_metadata = process_indexfile(config_data.get('index', None), data_metadata,
+                                                  config_data.get('default_group', 'n/a'), select_groups=False)
+
     # Apply model
     syllables = apply_model(model_data['model'], model_data['whitening_parameters'], data_dict, data_metadata)
 

From 0ea762107862e9b8e8bc6784ab4b908ea07b651c Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 29 Mar 2023 20:33:11 -0400
Subject: [PATCH 06/29] chore: remove type and namedtuple import

---
 moseq2_model/train/util.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 0e716cf..75eeed3 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -1,12 +1,11 @@
 """
 ARHMM utility functions
 """
-from typing import Union
 import numpy as np
 from tqdm.auto import tqdm
 from functools import partial
 from cytoolz import valmap, itemmap
-from collections import OrderedDict, defaultdict, namedtuple
+from collections import OrderedDict, defaultdict
 from moseq2_model.util import save_arhmm_checkpoint, get_loglikelihoods
 
 
@@ -158,7 +157,7 @@ def get_labels_from_model(model):
 WhiteningParams = namedtuple('WhiteningParams', ['mu', 'L', 'offset'])
 
 
-def apply_model(model, whitening_params: Union[WhiteningParams, dict], data_dict, metadata):
+def apply_model(model, whitening_params, data_dict, metadata):
     '''
     Apply trained model to data_dict. Note that this function might produce unexpected behavior
     if the model was trained using separate transition matrices for different groups of sessions.

From 84db547a04650643deaad4ccaa224749869b644e Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 29 Mar 2023 20:59:36 -0400
Subject: [PATCH 07/29] chore: change model_data to dict data type

---
 moseq2_model/train/util.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 75eeed3..9f59d75 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -154,10 +154,7 @@ def get_labels_from_model(model):
     return labels
 
 
-WhiteningParams = namedtuple('WhiteningParams', ['mu', 'L', 'offset'])
-
-
-def apply_model(model, whitening_params, data_dict, metadata):
+def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
     '''
     Apply trained model to data_dict. Note that this function might produce unexpected behavior
     if the model was trained using separate transition matrices for different groups of sessions.
@@ -171,12 +168,14 @@ def apply_model(model, whitening_params, data_dict, metadata):
     Returns:
         labels (dict): dictionary of labels predicted per session after modeling
     '''
-    if isinstance(whitening_params, dict):
+
+    # check for whiten parameters to see if whiten_all or whiten_each
+    if whiten.lower[0].lower() == 'e':
         # this approach is not recommended, but supported
-        center = whitening_params[list(whitening_params)[0]].offset == 0
+        center = whitening_params[list(whitening_params)[0]]['offset'] == 0
         whitened_data, _ = whiten_each(data_dict, center)
     else:
-        whitened_data = valmap(lambda x: apply_whitening(x, whitening_params.mu, whitening_params.L, whitening_params.offset), data_dict)
+        whitened_data = valmap(lambda x: apply_whitening(x, whitening_params['mu'], whitening_params['L'], whitening_params['offset']), data_dict)
 
     # apply model to data
     if 'SeparateTrans' in type(model):
@@ -226,7 +225,8 @@ def whiten_all(data_dict, center=True):
     offset = 0. if center else mu
-    apply_whitening = partial(apply_whitening, mu=mu, L=L, offset=offset)
-    return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items()), WhiteningParams(mu, L, offset)
+    apply_whitening = partial(apply_whitening, mu=mu, L=L, offset=offset)
+    whitening_parameters = {'mu': mu, 'L': L, 'offset': offset}
+    return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items()), whitening_parameters

From 33f6c97eaadd1189c1d7f66c8b30572c612b559f Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 29 Mar 2023 21:47:44 -0400
Subject: [PATCH 08/29] chore: move apply_whitening back to lambda function

---
 moseq2_model/train/util.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 9f59d75..89bd675 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -169,6 +169,8 @@ def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
         labels (dict): dictionary of labels predicted per session after modeling
     '''
 
+    # whiten data function
+    apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
     # check for whiten parameters to see if whiten_all or whiten_each
     if whiten.lower[0].lower() == 'e':
         # this approach is not recommended, but supported
@@ -224,7 +226,7 @@ def whiten_all(data_dict, center=True):
     L = np.linalg.cholesky(Sigma)
     offset = 0. if center else mu
-    apply_whitening = partial(apply_whitening, mu=mu, L=L, offset=offset)
+    apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
     whitening_parameters = {'mu': mu, 'L': L, 'offset': offset}
     return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items()), whitening_parameters

From 0e82e372d2586c0ab9e7a876dd591b667730ed2e Mon Sep 17 00:00:00 2001
From: Sherry
Date: Thu, 30 Mar 2023 12:37:50 -0400
Subject: [PATCH 09/29] chore: remove unused apply_whitening function

---
 moseq2_model/train/util.py | 16 +---------------
 1 file changed, 1 insertion(+), 15 deletions(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 89bd675..67df054 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -189,21 +189,6 @@ def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
     return labels
 
 
-def apply_whitening(data, mu, L, offset=0):
-    '''Apply whitening to data.
-
-    Args:
-        data (np.ndarray): data to be whitened
-        mu (np.ndarray): mean of data
-        L (np.ndarray): Cholesky decomposition of covariance matrix
-        offset (float): offset to add to whitened data
-
-    Returns:
-        data (np.ndarray): whitened data
-    '''
-
-    return np.linalg.solve(L, (data - mu).T).T + offset
-
 
 # taken from moseq by @mattjj and @alexbw
 def whiten_all(data_dict, center=True):
@@ -211,6 +196,7 @@ def whiten_all(data_dict, center=True):
     L = np.linalg.cholesky(Sigma)
     offset = 0. if center else mu
+    # set up function to whiten data
     apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
     whitening_parameters = {'mu': mu, 'L': L, 'offset': offset}
     return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items()), whitening_parameters

From 22ab4d7dba8edb8d417ee915de4fb5a0d4982dab Mon Sep 17 00:00:00 2001
From: Sherry
Date: Thu, 30 Mar 2023 16:57:03 -0400
Subject: [PATCH 10/29] fix: typo in lower

---
 moseq2_model/train/util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 67df054..19c5096 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -172,7 +172,7 @@ def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
     # whiten data function
     apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
     # check for whiten parameters to see if whiten_all or whiten_each
-    if whiten.lower[0].lower() == 'e':
+    if whiten[0].lower() == 'e':
         # this approach is not recommended, but supported
         center = whitening_params[list(whitening_params)[0]]['offset'] == 0
         whitened_data, _ = whiten_each(data_dict, center)

From 73658460d6be3b262439e395c57ad829052ffa3f Mon Sep 17 00:00:00 2001
From: Sherry
Date: Thu, 30 Mar 2023 16:57:27 -0400
Subject: [PATCH 11/29] fix: change whitening function

---
 moseq2_model/train/util.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 19c5096..6c0d06f 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -170,14 +170,19 @@ def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
     '''
 
     # whiten data function
-    apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
+    try:
+        mu, L, offset = whitening_params['mu'], whitening_params['L'], whitening_params['offset']
+        apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
+    except:
+        print('Whitening parameters not found.')
+
     # check for whiten parameters to see if whiten_all or whiten_each
     if whiten[0].lower() == 'e':
         # this approach is not recommended, but supported
         center = whitening_params[list(whitening_params)[0]]['offset'] == 0
         whitened_data, _ = whiten_each(data_dict, center)
     else:
-        whitened_data = valmap(lambda x: apply_whitening(x, whitening_params['mu'], whitening_params['L'], whitening_params['offset']), data_dict)
+        whitened_data = valmap(lambda x: apply_whitening(x), data_dict)
 
     # apply model to data
     if 'SeparateTrans' in type(model):

From 726d37c0c86a69f66d285dfe2d48bef1feaa1792 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Fri, 31 Mar 2023 13:19:17 -0400
Subject: [PATCH 12/29] fix: save new model to similar format as old model

---
 moseq2_model/helpers/wrappers.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index 1b94d21..5952a12 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -6,6 +6,7 @@
 import sys
 import glob
 import click
+import numpy as np
 from copy import deepcopy
 from moseq2_model.train.util import train_model, run_e_step, apply_model
 from os.path import join, basename, realpath, dirname, splitext
@@ -194,10 +195,25 @@ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
                                                   config_data.get('default_group', 'n/a'), select_groups=False)
 
     # Apply model
-    syllables = apply_model(model_data['model'], model_data['whitening_parameters'], data_dict, data_metadata)
+    syllables = apply_model(model_data['model'], model_data['whitening_parameters'], data_dict, data_metadata, model_data['run_parameters']['whiten'])
+
+    # add -5 padding to the list of states
+    nlags = model_data['run_parameters'].get('nlags', 3)
+    for key in syllables.keys():
+        syllables[key] = np.append(np.repeat(-5, nlags), syllables[key])
+
+    # prepare model data dictionary to save
+    applied_model_data = {}
+    applied_model_data['labels'] = list(syllables.values())
+    applied_model_data['keys'] = list(syllables.keys())
+    applied_model_data['model_parameters'] = model_data['model_parameters']
+    applied_model_data['oracle_run_parameters'] = model_data['run_parameters']
+    applied_model_data['metadata'] = data_metadata
+    applied_model_data['model'] = model_data['model']
+    applied_model_data['whitening_parameters'] = model_data['whitening_parameters']
 
     # Save output
-    save_dict(filename=dest_file, obj_to_save=syllables)
+    save_dict(filename=dest_file, obj_to_save=applied_model_data)

From f4a38cf3fc26d6d9b945ecc53e08cfb688939b66 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Fri, 31 Mar 2023 13:48:23 -0400
Subject: [PATCH 13/29] fix: cast model type to str

---
 moseq2_model/train/util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 6c0d06f..c7140d5 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -185,7 +185,7 @@ def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
         whitened_data = valmap(lambda x: apply_whitening(x), data_dict)
 
     # apply model to data
-    if 'SeparateTrans' in type(model):
+    if 'SeparateTrans' in str(type(model)):
         # not recommended, but supported
         labels = itemmap(lambda item: (item[0], model.heldout_viterbi(item[1], group_id=metadata['groups'][item[0]])), whitened_data)
     else:

From 1834900931b20e2247e9a9ea7883db66085a5dcf Mon Sep 17 00:00:00 2001
From: Sherry
Date: Fri, 31 Mar 2023 14:39:43 -0400
Subject: [PATCH 14/29] chore: fix the tests

---
 tests/integration_tests/test_data_helper.py | 8 ++++----
 tests/unit_tests/test_train_models.py       | 2 +-
 tests/unit_tests/test_train_utils.py        | 16 ++++++++--------
 tests/unit_tests/test_util.py               | 4 ++--
 4 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/tests/integration_tests/test_data_helper.py b/tests/integration_tests/test_data_helper.py
index b0ea783..bc94784 100644
--- a/tests/integration_tests/test_data_helper.py
+++ b/tests/integration_tests/test_data_helper.py
@@ -76,7 +76,7 @@ def test_prepare_model_metadata(self):
         data_dict, data_metadata = select_data_to_model(index_data, data_dict, data_metadata)
 
-        data_dict1, model_parameters, train_list, hold_out_list = \
+        data_dict1, model_parameters, train_list, hold_out_list, whitening_parameters = \
             prepare_model_metadata(data_dict, data_metadata, config_data)
 
         assert data_dict.values() != data_dict1.values(), "Index loaded uuids and training data does not match scores file"
@@ -85,7 +85,7 @@ def test_prepare_model_metadata(self):
         config_data['whiten'] = 'each'
         config_data['noise_level'] = 1
-        data_dict1, model_parameters, train_list, hold_out_list = \
+        data_dict1, model_parameters, train_list, hold_out_list, whitening_parameters = \
            prepare_model_metadata(data_dict, data_metadata, config_data)
 
         assert data_dict.values() != data_dict1.values(), "Index loaded uuids and training data does not match scores file"
@@ -93,7 +93,7 @@ def test_prepare_model_metadata(self):
         assert hold_out_list == [], "Some of the data is unintentionally held out"
 
         config_data['whiten'] = 'none'
-        data_dict1, model_parameters, train_list, hold_out_list = \
+        data_dict1, model_parameters, train_list, hold_out_list, whitening_parameters = \
             prepare_model_metadata(data_dict, data_metadata, config_data)
 
         assert data_dict.values() != data_dict1.values(), "Index loaded uuids and training data does not match scores file"
@@ -121,7 +121,7 @@ def test_get_heldout_data_splits(self):
         data_dict, data_metadata = select_data_to_model(index_data, data_dict, data_metadata)
 
-        data_dict, model_parameters, train_list, hold_out_list = \
+        data_dict, model_parameters, train_list, hold_out_list, whitening_parameters = \
             prepare_model_metadata(data_dict, data_metadata, config_data)
 
         assert (sorted(train_list) != sorted(hold_out_list)), "Training data is the same as held out data"
diff --git a/tests/unit_tests/test_train_models.py b/tests/unit_tests/test_train_models.py
index 054103f..26fdb5a 100644
--- a/tests/unit_tests/test_train_models.py
+++ b/tests/unit_tests/test_train_models.py
@@ -60,7 +60,7 @@ def test_ARHMM(self):
         nkeys = 5
         all_keys = ['key1', 'key2', 'key3', 'key4', 'key5']
 
-        data_dict, model_parameters, train_list, hold_out_list = \
+        data_dict, model_parameters, train_list, hold_out_list, whitening_parameters = \
            prepare_model_metadata(data_dict, data_metadata, config_data)
 
         model_parameters['separate_trans'] = False
diff --git a/tests/unit_tests/test_train_utils.py b/tests/unit_tests/test_train_utils.py
index ed4a597..40616ac 100644
--- a/tests/unit_tests/test_train_utils.py
+++ b/tests/unit_tests/test_train_utils.py
@@ -28,7 +28,7 @@ def get_model(separate_trans=False, robust=False, groups=[]):
                                          load_groups=True)
 
     data_metadata['groups'] = {k: g for k, g in zip(data_dict, groups)}
-    data_dict, model_parameters, _, _ = \
+    data_dict, model_parameters, _, _, whitening_parameters = \
         prepare_model_metadata(data_dict, data_metadata, config_data)
 
     arhmm = ARHMM(data_dict=data_dict, **model_parameters)
@@ -45,7 +45,7 @@ def test_train_model(self):
 
         model, data_dict = get_model()
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,
@@ -68,7 +68,7 @@ def test_train_model(self):
 
         model, data_dict = get_model(separate_trans=True, groups=['default', 'Group1'])
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,
@@ -91,7 +91,7 @@ def test_get_labels_from_model(self):
 
         model, data_dict = get_model()
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,
@@ -105,16 +105,16 @@ def test_whiten_all(self):
 
         _, data_dict = get_model()
 
-        whitened_a = whiten_all(data_dict)
-        whitened_e = whiten_each(data_dict)
+        whitened_a, whitening_parameters = whiten_all(data_dict)
+        whitened_e, whitening_parameters = whiten_each(data_dict)
         assert data_dict.values() != whitened_a.values()
         assert whitened_a.values() != whitened_e.values()
 
     def test_whiten_each(self):
         _, data_dict = get_model()
 
-        whitened_a = whiten_all(data_dict, center=False)
-        whitened_e = whiten_each(data_dict, center=False)
+        whitened_a, whitening_parameters = whiten_all(data_dict, center=False)
+        whitened_e, whitening_parameters = whiten_each(data_dict, center=False)
 
         assert data_dict.values() != whitened_a.values()
         assert whitened_a.values() != whitened_e.values()
diff --git a/tests/unit_tests/test_util.py b/tests/unit_tests/test_util.py
index 30cbd98..87baa11 100644
--- a/tests/unit_tests/test_util.py
+++ b/tests/unit_tests/test_util.py
@@ -309,7 +309,7 @@ def test_copy_model(self):
         with open(config_file, 'r') as f:
             config_data = yaml.safe_load(f)
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,
@@ -335,7 +335,7 @@ def check_params(model, params):
         with open(config_file, 'r') as f:
             config_data = yaml.safe_load(f)
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,

From 69c6a55440f8e3a3ae97b1a2cd572389f56e49b1 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Sat, 1 Apr 2023 13:48:55 -0400
Subject: [PATCH 15/29] chore: add functions to make dir for new models

---
 moseq2_model/helpers/wrappers.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index 5952a12..6c2a944 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -183,6 +183,14 @@ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
     Returns:
         None
     """
+
+    assert splitext(basename(dest_file))[-1] in ['.mat', '.z', '.pkl', '.p', '.h5'], 'Incorrect model filetype'
+    os.makedirs(dirname(dest_file), exist_ok=True)
+
+    if not os.access(dirname(dest_file), os.W_OK):
+        raise IOError('Output directory is not writable.')
+
+
     # Load model
     model_data = load_dict(model_file)
 

From 216ae65584f6779f8e35eb5ad456d1529e6ec952 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Sat, 1 Apr 2023 13:49:21 -0400
Subject: [PATCH 16/29] chore: add apply_model_command for notebook

---
 moseq2_model/gui.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/moseq2_model/gui.py b/moseq2_model/gui.py
index 1fadfe2..9ce0944 100644
--- a/moseq2_model/gui.py
+++ b/moseq2_model/gui.py
@@ -5,7 +5,7 @@ import ruamel.yaml as yaml
 from moseq2_model.cli import learn_model, kappa_scan_fit_models
 from os.path import dirname, join, exists
-from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper
+from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper, apply_model_wrapper
 
 def learn_model_command(progress_paths, get_cmd=True, verbose=False):
     """
@@ -60,4 +60,18 @@ def learn_model_command(progress_paths, get_cmd=True, verbose=False):
         command = kappa_scan_fit_models_wrapper(input_file, config_data, output_dir)
         return command
     else:
-        learn_model_wrapper(input_file, dest_file, config_data)
\ No newline at end of file
+        learn_model_wrapper(input_file, dest_file, config_data)
+
+
+def apply_model_command(progress_paths, model_file):
+    # Load proper input variables
+    pc_file = progress_paths['scores_path']
+    dest_file = progress_paths['model_path']
+    config_file = progress_paths['config_file']
+    index = progress_paths['index_file']
+    output_dir = progress_paths['base_model_path']
+
+    with open(config_file, 'r') as f:
+        config_data = yaml.safe_load(f)
+
+    apply_model_wrapper(model_file, pc_file, dest_file, config_data)
\ No newline at end of file

From 34a867dd11c2aff52c3feb0838210bacacf9b82c Mon Sep 17 00:00:00 2001
From: Sherry
Date: Sat, 1 Apr 2023 13:53:49 -0400
Subject: [PATCH 17/29] chore: add docstring and comments

---
 moseq2_model/gui.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/moseq2_model/gui.py b/moseq2_model/gui.py
index 9ce0944..af02169 100644
--- a/moseq2_model/gui.py
+++ b/moseq2_model/gui.py
@@ -64,6 +64,13 @@ def learn_model_command(progress_paths, get_cmd=True, verbose=False):
 
 def apply_model_command(progress_paths, model_file):
+    """Apply a pre-trained ARHMM model to a new dataset from within a Jupyter notebook.
+
+    Args:
+        progress_paths (dict): notebook progress dict that contains paths to the pc scores, config, and index files.
+        model_file (str): path to the pre-trained ARHMM model.
+    """
+
     # Load proper input variables
     pc_file = progress_paths['scores_path']
     dest_file = progress_paths['model_path']
     config_file = progress_paths['config_file']
     index = progress_paths['index_file']
     output_dir = progress_paths['base_model_path']
 
+    # load config data
     with open(config_file, 'r') as f:
         config_data = yaml.safe_load(f)
 
+    # apply model to data
     apply_model_wrapper(model_file, pc_file, dest_file, config_data)
\ No newline at end of file

From 6ad0bb0e20c67b6c3e9f33da029062fb49a98ccf Mon Sep 17 00:00:00 2001
From: Sherry
Date: Sat, 1 Apr 2023 14:01:25 -0400
Subject: [PATCH 18/29] chore: fix arhmm model to arhmm

---
 moseq2_model/cli.py          | 4 ++--
 moseq2_model/gui.py          | 4 ++--
 moseq2_model/train/models.py | 4 ++--
 moseq2_model/train/util.py   | 2 +-
 moseq2_model/util.py         | 4 ++--
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/moseq2_model/cli.py b/moseq2_model/cli.py
index e4162d8..aa47fd6 100644
--- a/moseq2_model/cli.py
+++ b/moseq2_model/cli.py
@@ -82,7 +82,7 @@ def learn_model(input_file, dest_file, **config_data):
     learn_model_wrapper(input_file, dest_file, config_data)
 
 
-@cli.command(name='apply-model', help='Apply trained ARHMM model to PC scores.')
+@cli.command(name='apply-model', help='Apply trained ARHMM to PC scores.')
 @click.argument("model_file", type=click.Path(exists=True))
 @click.argument("pc_file", type=click.Path(exists=True))
 @click.argument("dest_file", type=click.Path(file_okay=True, writable=True, resolve_path=True))
@@ -90,7 +90,7 @@ def learn_model(input_file, dest_file, **config_data):
 @click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions")
 @click.option("--load-groups", type=bool, default=True, help="If groups should be loaded with the PC scores.")
 def apply_model(model_file, pc_file, dest_file, **config_data):
-    # Apply the ARHMM model located in MODEL_FILE to the PC scores in PC_FILE, and saves the results to DEST_FILE
+    # Apply the ARHMM located in MODEL_FILE to the PC scores in PC_FILE, and saves the results to DEST_FILE
diff --git a/moseq2_model/gui.py b/moseq2_model/gui.py
index af02169..5464982 100644
--- a/moseq2_model/gui.py
+++ b/moseq2_model/gui.py
@@ -64,11 +64,11 @@ def learn_model_command(progress_paths, get_cmd=True, verbose=False):
 
 def apply_model_command(progress_paths, model_file):
-    """Apply a pre-trained ARHMM model to a new dataset from within a Jupyter notebook.
+    """Apply a trained ARHMM to a new dataset from within a Jupyter notebook.
 
     Args:
         progress_paths (dict): notebook progress dict that contains paths to the pc scores, config, and index files.
-        model_file (str): path to the pre-trained ARHMM model.
+        model_file (str): path to the pre-trained ARHMM.
     """
 
diff --git a/moseq2_model/train/models.py b/moseq2_model/train/models.py
index 6b3898a..857ca7b 100644
--- a/moseq2_model/train/models.py
+++ b/moseq2_model/train/models.py
@@ -1,5 +1,5 @@
 """
-ARHMM model initialization utilities.
+ARHMM initialization utilities.
 """
 
 import warnings
@@ -54,7 +54,7 @@ def ARHMM(data_dict, kappa=1e6, gamma=999, nlags=3, alpha=5.7, affine=True,
           model_hypparams={}, obs_hypparams={}, sticky_init=False, separate_trans=False,
           groups=None, robust=False, silent=False):
     """
-    Initialize ARHMM and add data and group labels to the ARHMM model.
+    Initialize ARHMM and add data and group labels to the ARHMM.
 
     Args:
         data_dict (OrderedDict): training data to add to model
diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index c7140d5..6ab271e 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -144,7 +144,7 @@ def get_labels_from_model(model):
     Grab model labels for each training dataset and place them in a list.
 
     Args:
-        model (ARHMM): trained ARHMM model
+        model (ARHMM): trained ARHMM
 
     Returns:
         labels (list): An array of predicted syllable labels for each training session
diff --git a/moseq2_model/util.py b/moseq2_model/util.py
index e8fb2a2..5ba3ac3 100644
--- a/moseq2_model/util.py
+++ b/moseq2_model/util.py
@@ -165,7 +165,7 @@ def get_loglikelihoods(arhmm, data, groups, separate_trans, normalize=True):
     Compute the log-likelihoods of the training sessions.
 
     Args:
-        arhmm (ARHMM): ARHMM model.
+        arhmm (ARHMM): the ARHMM model object.
         data (dict): dict object with UUID keys containing the PCS used for training.
         groups (list): list of assigned groups for all corresponding session uuids.
         separate_trans (bool): flag to compute separate log-likelihoods for each modeled group.
@@ -319,7 +319,7 @@ def save_arhmm_checkpoint(filename: str, arhmm: dict):
 
     Args:
         filename (str): path that specifies the checkpoint
-        arhmm (dict): a dictionary containing the arhmm model object, training iteration number, log-likelihoods of each training step, and labels for each step.
+        arhmm (dict): a dictionary containing the arhmm object, training iteration number, log-likelihoods of each training step, and labels for each step.
""" # Getting model object From 9ce8960044803be2570ae3713c54c59eb14390b4 Mon Sep 17 00:00:00 2001 From: Sherry Date: Mon, 3 Apr 2023 13:41:35 -0400 Subject: [PATCH 19/29] chore: change trained to pre-trained --- moseq2_model/cli.py | 2 +- moseq2_model/gui.py | 2 +- moseq2_model/helpers/wrappers.py | 4 ++-- moseq2_model/train/util.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/moseq2_model/cli.py b/moseq2_model/cli.py index aa47fd6..8bce4af 100644 --- a/moseq2_model/cli.py +++ b/moseq2_model/cli.py @@ -82,7 +82,7 @@ def learn_model(input_file, dest_file, **config_data): learn_model_wrapper(input_file, dest_file, config_data) -@cli.command(name='apply-model', help='Apply trained ARHMM to PC scores.') +@cli.command(name='apply-model', help='Apply pre-trained ARHMM to PC scores.') @click.argument("model_file", type=click.Path(exists=True)) @click.argument("pc_file", type=click.Path(exists=True)) @click.argument("dest_file", type=click.Path(file_okay=True, writable=True, resolve_path=True)) diff --git a/moseq2_model/gui.py b/moseq2_model/gui.py index 5464982..7c67fac 100644 --- a/moseq2_model/gui.py +++ b/moseq2_model/gui.py @@ -64,7 +64,7 @@ def learn_model_command(progress_paths, get_cmd=True, verbose=False): def apply_model_command(progress_paths, model_file): - """Apply a trained ARHMM to a new dataset from within a Jupyter notebook. + """Apply a pre-trained ARHMM to a new dataset from within a Jupyter notebook. Args: progress_paths (dict): notebook progress dict that contains paths to the pc scores, config, and index files. diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py index 6c2a944..9a6d23d 100644 --- a/moseq2_model/helpers/wrappers.py +++ b/moseq2_model/helpers/wrappers.py @@ -173,10 +173,10 @@ def learn_model_wrapper(input_file, dest_file, config_data): def apply_model_wrapper(model_file, pc_file, dest_file, config_data): """ - Wrapper function to apply a trained model to new data. + Wrapper function to apply a pre-trained model to new data. Args: - model_file (str): Path to trained model file + model_file (str): Path to pre-trained model file pc_file (str): Path to PC scores file dest_file (str): Path to save output file diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py index 6ab271e..1e2c91e 100644 --- a/moseq2_model/train/util.py +++ b/moseq2_model/train/util.py @@ -156,11 +156,11 @@ def get_labels_from_model(model): def apply_model(model, whitening_params, data_dict, metadata, whiten='all'): ''' - Apply trained model to data_dict. Note that this function might produce unexpected behavior + Apply pre-trained model to data_dict. Note that this function might produce unexpected behavior if the model was trained using separate transition matrices for different groups of sessions. 
 
     Args:
-        model (ARHMM): trained model
+        model (ARHMM): pre-trained model
         whitening_params (namedtuple or dict): whitening parameters
         data_dict (OrderedDict): data to apply model to
         metadata (dict): metadata for data_dict

From 134aef52109fe30e048fbe30b6f266b2f2b85d59 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Mon, 3 Apr 2023 15:43:49 -0400
Subject: [PATCH 20/29] chore: clean up the code with a for loop

---
 moseq2_model/helpers/wrappers.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index 9a6d23d..302745d 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -211,14 +211,15 @@ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
         syllables[key] = np.append(np.repeat(-5, nlags), syllables[key])
 
     # prepare model data dictionary to save
+    # save applied model data
     applied_model_data = {}
     applied_model_data['labels'] = list(syllables.values())
     applied_model_data['keys'] = list(syllables.keys())
-    applied_model_data['model_parameters'] = model_data['model_parameters']
-    applied_model_data['oracle_run_parameters'] = model_data['run_parameters']
     applied_model_data['metadata'] = data_metadata
-    applied_model_data['model'] = model_data['model']
-    applied_model_data['whitening_parameters'] = model_data['whitening_parameters']
+
+    # copy over pre-trained model data
+    for key in ['model_parameters', 'run_parameters', 'model', 'whitening_parameters']:
+        applied_model_data[key] = model_data[key]
 
     # Save output
     save_dict(filename=dest_file, obj_to_save=applied_model_data)

From b1edb98e413ac686d7bd89ab44f94ec51defbae3 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Tue, 4 Apr 2023 14:27:32 -0400
Subject: [PATCH 21/29] chore: add pc score to model object

---
 moseq2_model/helpers/wrappers.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index 302745d..6aba94f 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -158,6 +158,7 @@ def learn_model_wrapper(input_file, dest_file, config_data):
         'train_ll': train_ll,
         'expected_states': expected_states if config_data['e_step'] else None,
         'whitening_parameters': whitening_parameters,
+        'pc_scores': os.path.abspath(pc_scores)
     }
 
     # Save model

From 1e0f6c78458a1bbc6901c2a03ad09a3dc60ca611 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Tue, 4 Apr 2023 18:18:57 -0400
Subject: [PATCH 22/29] chore: saving pc_file and model file in the model object

---
 moseq2_model/helpers/wrappers.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index 6aba94f..042fc37 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -158,7 +158,7 @@ def learn_model_wrapper(input_file, dest_file, config_data):
         'train_ll': train_ll,
         'expected_states': expected_states if config_data['e_step'] else None,
         'whitening_parameters': whitening_parameters,
-        'pc_scores': os.path.abspath(pc_scores)
+        'pc_scores': os.path.abspath(input_file)
     }
 
     # Save model
@@ -217,7 +217,9 @@ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
     applied_model_data['labels'] = list(syllables.values())
     applied_model_data['keys'] = list(syllables.keys())
     applied_model_data['metadata'] = data_metadata
-
+    applied_model_data['pc_scores'] = os.path.abspath(pc_file)
+    applied_model_data['pre_trained_model'] = os.path.abspath(model_file)
+
     # copy over pre-trained model data
     for key in ['model_parameters', 'run_parameters', 'model', 'whitening_parameters']:
         applied_model_data[key] = model_data[key]

From 37fb3ca0dcb6923a3ea604cf82ec5cb7620032d1 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Tue, 4 Apr 2023 18:20:37 -0400
Subject: [PATCH 23/29] fix: remove vanilla try except

---
 moseq2_model/train/util.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 1e2c91e..5e06912 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -170,11 +170,8 @@ def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
     '''
 
     # whiten data function
-    try:
-        mu, L, offset = whitening_params['mu'], whitening_params['L'], whitening_params['offset']
-        apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
-    except:
-        print('Whitening parameters not found.')
+    mu, L, offset = whitening_params['mu'], whitening_params['L'], whitening_params['offset']
+    apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
 
     # check for whiten parameters to see if whiten_all or whiten_each
     if whiten[0].lower() == 'e':

From 8e6b27857d4bfd94e4ef0ee27073a866368b73da Mon Sep 17 00:00:00 2001
From: Sherry
Date: Tue, 4 Apr 2023 18:35:33 -0400
Subject: [PATCH 24/29] chore: ensure whitening parameters exist

---
 moseq2_model/helpers/wrappers.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index 042fc37..a543c7c 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -195,6 +195,9 @@ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
     # Load model
     model_data = load_dict(model_file)
 
+    if model_data.get('whitening_parameters') is None:
+        raise KeyError('Whitening parameters not found in model file. Unable to apply model to new data. Please retrain the model using the latest version.')
+
     # Load PC scores
     data_dict, data_metadata = load_pcs(filename=pc_file, var_name=config_data.get('var_name', 'scores'), npcs=model_data['run_parameters']['npcs'],
                                         load_groups=config_data.get('load_groups', False))

From 3bcb731e84c779cd9bf65f59bcb5b6087728c175 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 5 Apr 2023 12:10:04 -0400
Subject: [PATCH 25/29] chore: rename dictionary

---
 moseq2_model/helpers/wrappers.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index a543c7c..e017587 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -158,7 +158,7 @@ def learn_model_wrapper(input_file, dest_file, config_data):
         'train_ll': train_ll,
         'expected_states': expected_states if config_data['e_step'] else None,
         'whitening_parameters': whitening_parameters,
-        'pc_scores': os.path.abspath(input_file)
+        'pc_score_path': os.path.abspath(input_file)
     }
 
     # Save model
@@ -220,8 +220,8 @@ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
     applied_model_data['labels'] = list(syllables.values())
     applied_model_data['keys'] = list(syllables.keys())
     applied_model_data['metadata'] = data_metadata
-    applied_model_data['pc_scores'] = os.path.abspath(pc_file)
-    applied_model_data['pre_trained_model'] = os.path.abspath(model_file)
+    applied_model_data['pc_score_path'] = os.path.abspath(pc_file)
+    applied_model_data['pre_trained_model_path'] = os.path.abspath(model_file)
 
     # copy over pre-trained model data
     for key in ['model_parameters', 'run_parameters', 'model', 'whitening_parameters']:

From cf4292ec6bd86471be047306e87c6fcb70ff082d Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 5 Apr 2023 12:31:20 -0400
Subject: [PATCH 26/29] chore: simplify with valmap

---
 moseq2_model/helpers/wrappers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py
index e017587..000dcb5 100644
--- a/moseq2_model/helpers/wrappers.py
+++ b/moseq2_model/helpers/wrappers.py
@@ -8,6 +8,7 @@
 import click
 import numpy as np
 from copy import deepcopy
+from cytoolz import valmap
 from moseq2_model.train.util import train_model, run_e_step, apply_model
 from os.path import join, basename, realpath, dirname, splitext
@@ -211,8 +212,7 @@ def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
 
     # add -5 padding to the list of states
     nlags = model_data['run_parameters'].get('nlags', 3)
-    for key in syllables.keys():
-        syllables[key] = np.append(np.repeat(-5, nlags), syllables[key])
+    syllables = valmap(lambda v: np.concatenate(([-5] * nlags, v)), syllables)

From f2e8d0ee3f0045ade58de12ef393a12828516c15 Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 5 Apr 2023 12:35:54 -0400
Subject: [PATCH 27/29] chore: simplify whitened_data

---
 moseq2_model/train/util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index 5e06912..379e653 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -179,7 +179,7 @@ def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
         center = whitening_params[list(whitening_params)[0]]['offset'] == 0
         whitened_data, _ = whiten_each(data_dict, center)
     else:
-        whitened_data = valmap(lambda x: apply_whitening(x), data_dict)
+        whitened_data = valmap(apply_whitening, data_dict)
 
     # apply model to data
     if 'SeparateTrans' in str(type(model)):

From b146e4b875eb851cc2724cdfc1399aa766e6ae9e Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 5 Apr 2023 16:42:10 -0400
Subject: [PATCH 28/29] chore: version bump

---
 moseq2_model/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/moseq2_model/__init__.py b/moseq2_model/__init__.py
index 218644a..7f53bea 100644
--- a/moseq2_model/__init__.py
+++ b/moseq2_model/__init__.py
@@ -1 +1 @@
-__version__ = 'v1.1.2'
\ No newline at end of file
+__version__ = 'v1.2.0'
\ No newline at end of file

From b8dfb512e019f0878b6f1cf1df622841e1f2152b Mon Sep 17 00:00:00 2001
From: Sherry
Date: Wed, 5 Apr 2023 17:36:58 -0400
Subject: [PATCH 29/29] chore: pip install numpy

---
 .travis.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.travis.yml b/.travis.yml
index 0d21403..2e9e63f 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -14,6 +14,7 @@ jobs:
       stage: latest-pythons
       before_install:
         - pip install -U pip
+        - pip install numpy==1.18.3
         - pip install pytest==5.4.1 codecov pytest-cov
         - export PYTHONPATH=$PYTHONPATH:$(pwd)
      install:
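
For reference, once this series is applied, a pre-trained model can be applied to a new set of PC scores either from Python or from the command line. The following is a minimal sketch, assuming the package's usual `moseq2-model` console entry point and using hypothetical placeholder file names (none of these files ship with the repository):

    # Illustrative sketch only: every path below is a hypothetical placeholder.
    # The model file must come from a learn-model run that stored whitening
    # parameters; otherwise apply_model_wrapper raises a KeyError (see PATCH 24).
    from moseq2_model.helpers.wrappers import apply_model_wrapper

    config_data = {'var_name': 'scores', 'load_groups': True, 'index': 'moseq2-index.yaml'}
    apply_model_wrapper('pretrained_model.p', 'new_pca_scores.h5', 'applied_model.p', config_data)

The equivalent CLI call would be: moseq2-model apply-model pretrained_model.p new_pca_scores.h5 applied_model.p --index moseq2-index.yaml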