diff --git a/.travis.yml b/.travis.yml
index 0d21403..2e9e63f 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -14,6 +14,7 @@ jobs:
     stage: latest-pythons
     before_install:
       - pip install -U pip
+      - pip install numpy==1.18.3
       - pip install pytest==5.4.1 codecov pytest-cov
       - export PYTHONPATH=$PYTHONPATH:$(pwd)
     install:
diff --git a/moseq2_model/__init__.py b/moseq2_model/__init__.py
index 218644a..7f53bea 100644
--- a/moseq2_model/__init__.py
+++ b/moseq2_model/__init__.py
@@ -1 +1 @@
-__version__ = 'v1.1.2'
\ No newline at end of file
+__version__ = 'v1.2.0'
\ No newline at end of file
diff --git a/moseq2_model/cli.py b/moseq2_model/cli.py
index 864ff5d..8bce4af 100644
--- a/moseq2_model/cli.py
+++ b/moseq2_model/cli.py
@@ -6,7 +6,7 @@ import click
 from os.path import join
 from moseq2_model.util import count_frames as count_frames_wrapper
-from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper
+from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper, apply_model_wrapper
 
 orig_init = click.core.Option.__init__
 
@@ -46,7 +46,7 @@ def modeling_parameters(function):
     function = click.option('--e-step', is_flag=True, help="Compute the expected state sequence for each recordings")(function)
     function = click.option("--save-every", "-s", type=int, default=-1, help="Increment to save labels and model object (-1 for just last)")(function)
-    function = click.option("--save-model", is_flag=True, help="Save model object at the end of training")(function)
+    function = click.option("--save-model", type=bool, default=True, help="Save model object at the end of training")(function)
     function = click.option("--max-states", "-m", type=int, default=100, help="Maximum number of states")(function)
     function = click.option("--npcs", type=int, default=10, help="Number of PCs to use")(function)
     function = click.option("--whiten", "-w", type=str, default='all', help="Whiten PCs: (e)each session (a)ll combined or (n)o whitening")(function)
@@ -73,7 +73,7 @@ def modeling_parameters(function):
 @click.option("--kappa", "-k", type=float, default=None, help="Kappa; hyperparameter used to set syllable duration. Larger k = longer syllable lengths")
 @click.option("--checkpoint-freq", type=int, default=-1, help='save model checkpoint every n iterations')
 @click.option("--use-checkpoint", is_flag=True, help='indicate whether to use previously saved checkpoint')
-@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions (used only with the separate-trans flag)")
+@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions")
 @click.option("--default-group", type=str, default="n/a", help="Default group name to use for separate-trans")
 @click.option("--verbose", '-v', is_flag=True, help="Print syllable log-likelihoods during training.")
 def learn_model(input_file, dest_file, **config_data):
@@ -81,6 +81,20 @@ def learn_model(input_file, dest_file, **config_data):
 
     learn_model_wrapper(input_file, dest_file, config_data)
 
+
+@cli.command(name='apply-model', help='Apply pre-trained ARHMM to PC scores.')
+@click.argument("model_file", type=click.Path(exists=True))
+@click.argument("pc_file", type=click.Path(exists=True))
+@click.argument("dest_file", type=click.Path(file_okay=True, writable=True, resolve_path=True))
+@click.option("--var-name", type=str, default='scores', help="Variable name in input file with PCs")
+@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions")
+@click.option("--load-groups", type=bool, default=True, help="If groups should be loaded with the PC scores.")
+def apply_model(model_file, pc_file, dest_file, **config_data):
+    # Apply the ARHMM located in MODEL_FILE to the PC scores in PC_FILE, and save the results to DEST_FILE
+
+    apply_model_wrapper(model_file, pc_file, dest_file, config_data)
+
+
 @cli.command(name='kappa-scan', help='Batch train multiple model to scan over different kappa values.')
 @click.argument('input_file', type=click.Path(exists=True))
 @click.argument('output_dir', type=click.Path(exists=False))
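A minimal usage sketch for the new apply-model command (file names are illustrative, and the installed console-script name, e.g. moseq2-model, is an assumption). The same arguments can be exercised through click's test runner:

    from click.testing import CliRunner
    from moseq2_model.cli import cli

    # illustrative paths: pre-trained model, PC scores to label, and the output file
    runner = CliRunner()
    result = runner.invoke(cli, ['apply-model', 'pretrained-model.p', 'pca_scores.h5',
                                 'applied-model.p', '--index', 'moseq2-index.yaml'])
    print(result.exit_code, result.output)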
Larger k = longer syllable lengths") @click.option("--checkpoint-freq", type=int, default=-1, help='save model checkpoint every n iterations') @click.option("--use-checkpoint", is_flag=True, help='indicate whether to use previously saved checkpoint') -@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions (used only with the separate-trans flag)") +@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions") @click.option("--default-group", type=str, default="n/a", help="Default group name to use for separate-trans") @click.option("--verbose", '-v', is_flag=True, help="Print syllable log-likelihoods during training.") def learn_model(input_file, dest_file, **config_data): @@ -81,6 +81,20 @@ def learn_model(input_file, dest_file, **config_data): learn_model_wrapper(input_file, dest_file, config_data) + +@cli.command(name='apply-model', help='Apply pre-trained ARHMM to PC scores.') +@click.argument("model_file", type=click.Path(exists=True)) +@click.argument("pc_file", type=click.Path(exists=True)) +@click.argument("dest_file", type=click.Path(file_okay=True, writable=True, resolve_path=True)) +@click.option("--var-name", type=str, default='scores', help="Variable name in input file with PCs") +@click.option("--index", "-i", type=click.Path(), default="", help="Path to moseq2-index.yaml for group definitions") +@click.option("--load-groups", type=bool, default=True, help="If groups should be loaded with the PC scores.") +def apply_model(model_file, pc_file, dest_file, **config_data): + # Apply the ARHMM located in MODEL_FILE to the PC scores in PC_FILE, and saves the results to DEST_FILE + + apply_model_wrapper(model_file, pc_file, dest_file, config_data) + + @cli.command(name='kappa-scan', help='Batch train multiple model to scan over different kappa values.') @click.argument('input_file', type=click.Path(exists=True)) @click.argument('output_dir', type=click.Path(exists=False)) diff --git a/moseq2_model/gui.py b/moseq2_model/gui.py index 1fadfe2..7c67fac 100644 --- a/moseq2_model/gui.py +++ b/moseq2_model/gui.py @@ -5,7 +5,7 @@ import ruamel.yaml as yaml from moseq2_model.cli import learn_model, kappa_scan_fit_models from os.path import dirname, join, exists -from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper +from moseq2_model.helpers.wrappers import learn_model_wrapper, kappa_scan_fit_models_wrapper, apply_model_wrapper def learn_model_command(progress_paths, get_cmd=True, verbose=False): """ @@ -60,4 +60,27 @@ def learn_model_command(progress_paths, get_cmd=True, verbose=False): command = kappa_scan_fit_models_wrapper(input_file, config_data, output_dir) return command else: - learn_model_wrapper(input_file, dest_file, config_data) \ No newline at end of file + learn_model_wrapper(input_file, dest_file, config_data) + + +def apply_model_command(progress_paths, model_file): + """Apply a pre-trained ARHMM to a new dataset from within a Jupyter notebook. + + Args: + progress_paths (dict): notebook progress dict that contains paths to the pc scores, config, and index files. + model_file (str): path to the pre-trained ARHMM. 
+ """ + + # Load proper input variables + pc_file = progress_paths['scores_path'] + dest_file = progress_paths['model_path'] + config_file = progress_paths['config_file'] + index = progress_paths['index_file'] + output_dir = progress_paths['base_model_path'] + + # load config data + with open(config_file, 'r') as f: + config_data = yaml.safe_load(f) + + # apply model to data + apply_model_wrapper(model_file, pc_file, dest_file, config_data) \ No newline at end of file diff --git a/moseq2_model/helpers/data.py b/moseq2_model/helpers/data.py index 06aba2e..7a47540 100644 --- a/moseq2_model/helpers/data.py +++ b/moseq2_model/helpers/data.py @@ -174,12 +174,15 @@ def prepare_model_metadata(data_dict, data_metadata, config_data): model_parameters['groups'] = {k: data_metadata['groups'][k] for k in train} # Whiten the data + whitening_parameters = None if config_data['whiten'][0].lower() == 'a': click.echo('Whitening the training data using the whiten_all function') - data_dict = whiten_all(data_dict) + # in this case, whitening_parameters is a single tuple + data_dict, whitening_parameters = whiten_all(data_dict) elif config_data['whiten'][0].lower() == 'e': click.echo('Whitening the training data using the whiten_each function') - data_dict = whiten_each(data_dict) + # in this case, whitening_parameters is a dictionary of parameters + data_dict, whitening_parameters = whiten_each(data_dict) else: click.echo('Not whitening the data') @@ -189,7 +192,7 @@ def prepare_model_metadata(data_dict, data_metadata, config_data): for k, v in data_dict.items(): data_dict[k] = v + np.random.randn(*v.shape) * config_data['noise_level'] - return data_dict, model_parameters, train, hold_out + return data_dict, model_parameters, train, hold_out, whitening_parameters def get_heldout_data_splits(data_dict, train_list, hold_out_list): diff --git a/moseq2_model/helpers/wrappers.py b/moseq2_model/helpers/wrappers.py index ad5745f..000dcb5 100644 --- a/moseq2_model/helpers/wrappers.py +++ b/moseq2_model/helpers/wrappers.py @@ -6,12 +6,13 @@ import sys import glob import click +import numpy as np from copy import deepcopy -from collections import OrderedDict -from moseq2_model.train.util import train_model, run_e_step -from os.path import join, basename, realpath, dirname, exists, splitext +from cytoolz import valmap +from moseq2_model.train.util import train_model, run_e_step, apply_model +from os.path import join, basename, realpath, dirname, splitext from moseq2_model.util import (save_dict, load_pcs, get_parameters_from_model, copy_model, get_scan_range_kappas, - create_command_strings, get_current_model, get_loglikelihoods, get_session_groupings) + create_command_strings, get_current_model, get_loglikelihoods, get_session_groupings, load_dict) from moseq2_model.helpers.data import (process_indexfile, select_data_to_model, prepare_model_metadata, graph_modeling_loglikelihoods, get_heldout_data_splits, get_training_data_splits) @@ -56,20 +57,22 @@ def learn_model_wrapper(input_file, dest_file, config_data): load_groups=config_data['load_groups']) # Parse index file and update metadata information; namely groups + # If no group data in pca data, use group info from index file select_groups = config_data.get('select_groups', False) index_data, data_metadata = process_indexfile(config_data.get('index', None), data_metadata, config_data['default_group'], select_groups) # Get keys to include in training set + # TODO: select_groups not implemented if index_data is not None: data_dict, data_metadata = 
         data_dict, data_metadata = select_data_to_model(index_data, data_dict, data_metadata, select_groups)
-
+
     all_keys = list(data_dict)
     groups = list(data_metadata['groups'].values())
 
     # Get train/held out data split uuids
-    data_dict, model_parameters, train_list, hold_out_list = \
+    data_dict, model_parameters, train_list, hold_out_list, whitening_parameters = \
         prepare_model_metadata(data_dict, data_metadata, config_data)
 
     # Pack data dicts corresponding to uuids in train_list and hold_out_list
@@ -154,7 +157,9 @@
         'hold_out_list': hold_out_list,
         'train_list': train_list,
         'train_ll': train_ll,
-        'expected_states': expected_states if config_data['e_step'] else None
+        'expected_states': expected_states if config_data['e_step'] else None,
+        'whitening_parameters': whitening_parameters,
+        'pc_score_path': os.path.abspath(input_file)
     }
 
     # Save model
@@ -168,6 +173,64 @@
     return img_path
 
 
+def apply_model_wrapper(model_file, pc_file, dest_file, config_data):
+    """
+    Wrapper function to apply a pre-trained model to new data.
+
+    Args:
+        model_file (str): Path to pre-trained model file
+        pc_file (str): Path to PC scores file
+        dest_file (str): Path to save output file
+
+    Returns:
+        None
+    """
+
+    assert splitext(basename(dest_file))[-1] in ['.mat', '.z', '.pkl', '.p', '.h5'], 'Incorrect model filetype'
+    os.makedirs(dirname(dest_file), exist_ok=True)
+
+    if not os.access(dirname(dest_file), os.W_OK):
+        raise IOError('Output directory is not writable.')
+
+
+    # Load model
+    model_data = load_dict(model_file)
+
+    if model_data.get('whitening_parameters') is None:
+        raise KeyError('Whitening parameters not found in model file. Unable to apply model to new data. Please retrain the model using the latest version.')
+
+    # Load PC scores
+    data_dict, data_metadata = load_pcs(filename=pc_file, var_name=config_data.get('var_name', 'scores'), npcs=model_data['run_parameters']['npcs'],
+                                        load_groups=config_data.get('load_groups', False))
+
+    # parse group information from index file
+    index_data, data_metadata = process_indexfile(config_data.get('index', None), data_metadata,
+                                                  config_data.get('default_group', 'n/a'), select_groups=False)
+
+    # Apply model
+    syllables = apply_model(model_data['model'], model_data['whitening_parameters'], data_dict, data_metadata, model_data['run_parameters']['whiten'])
+
+    # prepend nlags frames of -5 padding to each state sequence
+    nlags = model_data['run_parameters'].get('nlags', 3)
+    syllables = valmap(lambda v: np.concatenate(([-5] * nlags, v)), syllables)
+
+    # prepare model data dictionary to save
+    # save applied model data
+    applied_model_data = {}
+    applied_model_data['labels'] = list(syllables.values())
+    applied_model_data['keys'] = list(syllables.keys())
+    applied_model_data['metadata'] = data_metadata
+    applied_model_data['pc_score_path'] = os.path.abspath(pc_file)
+    applied_model_data['pre_trained_model_path'] = os.path.abspath(model_file)
+
+    # copy over pre-trained model data
+    for key in ['model_parameters', 'run_parameters', 'model', 'whitening_parameters']:
+        applied_model_data[key] = model_data[key]
+
+    # Save output
+    save_dict(filename=dest_file, obj_to_save=applied_model_data)
+
+
 def kappa_scan_fit_models_wrapper(input_file, config_data, output_dir):
     """
     Wrapper function to output multiple model training commands for a range of kappa values.
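The file written by apply_model_wrapper is a plain results dictionary, so it can be re-opened with the load_dict helper added in moseq2_model/util.py below. A minimal sketch with an illustrative path:

    from moseq2_model.util import load_dict

    applied = load_dict('applied-model.p')
    print(applied['keys'])         # uuids of the sessions the model was applied to
    print(applied['labels'][0])    # syllable labels for the first session; the first nlags frames are padded with -5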
diff --git a/moseq2_model/train/models.py b/moseq2_model/train/models.py
index 6b3898a..857ca7b 100644
--- a/moseq2_model/train/models.py
+++ b/moseq2_model/train/models.py
@@ -1,5 +1,5 @@
 """
-ARHMM model initialization utilities.
+ARHMM initialization utilities.
 """
 
 import warnings
@@ -54,7 +54,7 @@ def ARHMM(data_dict, kappa=1e6, gamma=999, nlags=3, alpha=5.7, affine=True,
           model_hypparams={}, obs_hypparams={}, sticky_init=False, separate_trans=False,
           groups=None, robust=False, silent=False):
     """
-    Initialize ARHMM and add data and group labels to the ARHMM model.
+    Initialize ARHMM and add data and group labels to the ARHMM.
 
     Args:
         data_dict (OrderedDict): training data to add to model
diff --git a/moseq2_model/train/util.py b/moseq2_model/train/util.py
index e8ce2ca..379e653 100644
--- a/moseq2_model/train/util.py
+++ b/moseq2_model/train/util.py
@@ -1,13 +1,10 @@
 """
 ARHMM utility functions
 """
-
-import math
 import numpy as np
-from cytoolz import valmap
 from tqdm.auto import tqdm
-from scipy.stats import norm
 from functools import partial
+from cytoolz import valmap, itemmap
 from collections import OrderedDict, defaultdict
 from moseq2_model.util import save_arhmm_checkpoint, get_loglikelihoods
 
@@ -147,7 +144,7 @@ def get_labels_from_model(model):
     Grab model labels for each training dataset and place them in a list.
 
     Args:
-        model (ARHMM): trained ARHMM model
+        model (ARHMM): trained ARHMM
 
     Returns:
         labels (list): An array of predicted syllable labels for each training session
@@ -157,6 +154,44 @@
 
     return labels
 
+
+def apply_model(model, whitening_params, data_dict, metadata, whiten='all'):
+    '''
+    Apply pre-trained model to data_dict. Note that this function might produce unexpected behavior
+    if the model was trained using separate transition matrices for different groups of sessions.
+
+    Args:
+        model (ARHMM): pre-trained model
+        whitening_params (dict): whitening parameters saved with the trained model (a dict of per-session parameters if whiten_each was used)
+        data_dict (OrderedDict): data to apply model to
+        metadata (dict): metadata for data_dict
+
+    Returns:
+        labels (dict): dictionary of labels predicted per session after modeling
+    '''
+
+    # check whether the model was trained with whiten_all or whiten_each
+    if whiten[0].lower() == 'e':
+        # re-whitening each session independently is not recommended, but supported
+        center = whitening_params[list(whitening_params)[0]]['offset'] == 0
+        whitened_data, _ = whiten_each(data_dict, center)
+    else:
+        # re-apply the whitening transform saved during training
+        # (mu, L, offset from whiten_all)
+        mu, L, offset = whitening_params['mu'], whitening_params['L'], whitening_params['offset']
+        apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
+        whitened_data = valmap(apply_whitening, data_dict)
+
+    # apply model to data
+    if 'SeparateTrans' in str(type(model)):
+        # not recommended, but supported
+        labels = itemmap(lambda item: (item[0], model.heldout_viterbi(item[1], group_id=metadata['groups'][item[0]])), whitened_data)
+    else:
+        labels = valmap(model.heldout_viterbi, whitened_data)
+
+    return labels
+
+
 # taken from moseq by @mattjj and @alexbw
 def whiten_all(data_dict, center=True):
     """
@@ -178,9 +213,10 @@
     L = np.linalg.cholesky(Sigma)
     offset = 0. if center else mu
+    # set up function to whiten data
     apply_whitening = lambda x: np.linalg.solve(L, (x-mu).T).T + offset
-
-    return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items())
+    whitening_parameters = {'mu': mu, 'L': L, 'offset': offset}
+    return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items()), whitening_parameters
 
 
 # taken from moseq by @mattjj and @alexbw
@@ -195,12 +231,12 @@ def whiten_each(data_dict, center=True):
 
     Returns:
         data_dict (OrderedDict): Whitened training data dictionary
     """
-
+    whitening_parameters = {}
     for k, v in data_dict.items():
-        tmp_dict = whiten_all({k: v}, center=center)
+        tmp_dict, whitening_parameters[k] = whiten_all({k: v}, center=center)
         data_dict[k] = tmp_dict[k]
 
-    return data_dict
+    return data_dict, whitening_parameters
 
 
 def run_e_step(arhmm):
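To make the whitening change concrete: whiten_all and whiten_each now return the transform they estimated alongside the whitened data, and apply_model re-applies that stored transform to new PC scores instead of re-estimating it. A self-contained numpy sketch of the same computation (synthetic data; estimator details simplified relative to whiten_all):

    import numpy as np

    rng = np.random.default_rng(0)
    train = rng.normal(size=(500, 10))   # (frames, npcs) PC scores used for training
    new = rng.normal(size=(100, 10))     # new session to label with the trained model

    # parameters saved alongside the model (with center=True, offset stays 0)
    mu = train.mean(axis=0)
    L = np.linalg.cholesky(np.cov(train, rowvar=False, bias=True))
    offset = 0.
    whitening_parameters = {'mu': mu, 'L': L, 'offset': offset}

    # apply_model re-uses the stored transform rather than whitening the new data on its own
    whitened_new = np.linalg.solve(whitening_parameters['L'], (new - whitening_parameters['mu']).T).T + offset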
""" # Getting model object diff --git a/tests/integration_tests/test_data_helper.py b/tests/integration_tests/test_data_helper.py index b0ea783..bc94784 100644 --- a/tests/integration_tests/test_data_helper.py +++ b/tests/integration_tests/test_data_helper.py @@ -76,7 +76,7 @@ def test_prepare_model_metadata(self): data_dict, data_metadata = select_data_to_model(index_data, data_dict, data_metadata) - data_dict1, model_parameters, train_list, hold_out_list = \ + data_dict1, model_parameters, train_list, hold_out_list, whitening_parameters = \ prepare_model_metadata(data_dict, data_metadata, config_data) assert data_dict.values() != data_dict1.values(), "Index loaded uuids and training data does not match scores file" @@ -85,7 +85,7 @@ def test_prepare_model_metadata(self): config_data['whiten'] = 'each' config_data['noise_level'] = 1 - data_dict1, model_parameters, train_list, hold_out_list = \ + data_dict1, model_parameters, train_list, hold_out_list, whitening_parameters = \ prepare_model_metadata(data_dict, data_metadata, config_data) assert data_dict.values() != data_dict1.values(), "Index loaded uuids and training data does not match scores file" @@ -93,7 +93,7 @@ def test_prepare_model_metadata(self): assert hold_out_list == [], "Some of the data is unintentionally held out" config_data['whiten'] = 'none' - data_dict1, model_parameters, train_list, hold_out_list = \ + data_dict1, model_parameters, train_list, hold_out_list, whitening_parameters = \ prepare_model_metadata(data_dict, data_metadata, config_data) assert data_dict.values() != data_dict1.values(), "Index loaded uuids and training data does not match scores file" @@ -121,7 +121,7 @@ def test_get_heldout_data_splits(self): data_dict, data_metadata = select_data_to_model(index_data, data_dict, data_metadata) - data_dict, model_parameters, train_list, hold_out_list = \ + data_dict, model_parameters, train_list, hold_out_list, whitening_parameters = \ prepare_model_metadata(data_dict, data_metadata, config_data) assert (sorted(train_list) != sorted(hold_out_list)), "Training data is the same as held out data" diff --git a/tests/unit_tests/test_train_models.py b/tests/unit_tests/test_train_models.py index 054103f..26fdb5a 100644 --- a/tests/unit_tests/test_train_models.py +++ b/tests/unit_tests/test_train_models.py @@ -60,7 +60,7 @@ def test_ARHMM(self): nkeys = 5 all_keys = ['key1', 'key2', 'key3', 'key4', 'key5'] - data_dict, model_parameters, train_list, hold_out_list = \ + data_dict, model_parameters, train_list, hold_out_list, whitening_parameters = \ prepare_model_metadata(data_dict, data_metadata, config_data) model_parameters['separate_trans'] = False diff --git a/tests/unit_tests/test_train_utils.py b/tests/unit_tests/test_train_utils.py index ed4a597..40616ac 100644 --- a/tests/unit_tests/test_train_utils.py +++ b/tests/unit_tests/test_train_utils.py @@ -28,7 +28,7 @@ def get_model(separate_trans=False, robust=False, groups=[]): load_groups=True) data_metadata['groups'] = {k: g for k, g in zip(data_dict, groups)} - data_dict, model_parameters, _, _ = \ + data_dict, model_parameters, _, _, whitening_parameters = \ prepare_model_metadata(data_dict, data_metadata, config_data) arhmm = ARHMM(data_dict=data_dict, **model_parameters) @@ -45,7 +45,7 @@ def test_train_model(self): model, data_dict = get_model() - X = whiten_all(data_dict) + X, whitening_parameters = whiten_all(data_dict) training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X) model, lls, labels, iter_lls, iter_holls, _ 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,
@@ -68,7 +68,7 @@ def test_train_model(self):
 
         model, data_dict = get_model(separate_trans=True, groups=['default', 'Group1'])
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,
@@ -91,7 +91,7 @@ def test_get_labels_from_model(self):
 
         model, data_dict = get_model()
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,
@@ -105,16 +105,16 @@ def test_whiten_all(self):
 
         _, data_dict = get_model()
 
-        whitened_a = whiten_all(data_dict)
-        whitened_e = whiten_each(data_dict)
+        whitened_a, whitening_parameters = whiten_all(data_dict)
+        whitened_e, whitening_parameters = whiten_each(data_dict)
 
         assert data_dict.values() != whitened_a.values()
         assert whitened_a.values() != whitened_e.values()
 
     def test_whiten_each(self):
         _, data_dict = get_model()
 
-        whitened_a = whiten_all(data_dict, center=False)
-        whitened_e = whiten_each(data_dict, center=False)
+        whitened_a, whitening_parameters = whiten_all(data_dict, center=False)
+        whitened_e, whitening_parameters = whiten_each(data_dict, center=False)
 
         assert data_dict.values() != whitened_a.values()
         assert whitened_a.values() != whitened_e.values()
diff --git a/tests/unit_tests/test_util.py b/tests/unit_tests/test_util.py
index 30cbd98..87baa11 100644
--- a/tests/unit_tests/test_util.py
+++ b/tests/unit_tests/test_util.py
@@ -309,7 +309,7 @@ def test_copy_model(self):
         with open(config_file, 'r') as f:
             config_data = yaml.safe_load(f)
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,
@@ -335,7 +335,7 @@ def check_params(model, params):
         with open(config_file, 'r') as f:
             config_data = yaml.safe_load(f)
 
-        X = whiten_all(data_dict)
+        X, whitening_parameters = whiten_all(data_dict)
         training_data, validation_data = get_training_data_splits(config_data['percent_split'] / 100, X)
 
         model, lls, labels, iter_lls, iter_holls, _ = train_model(model, num_iter=5, train_data=training_data,