From b48e9acb50b856b148c720e6d235e479eac4ab33 Mon Sep 17 00:00:00 2001 From: Mats Date: Tue, 7 Jan 2025 14:49:24 +0000 Subject: [PATCH 01/10] add glm wrappers to preproc module --- osl_ephys/preprocessing/batch.py | 28 +++- osl_ephys/preprocessing/osl_wrappers.py | 208 +++++++++++++++++++++++- 2 files changed, 234 insertions(+), 2 deletions(-) diff --git a/osl_ephys/preprocessing/batch.py b/osl_ephys/preprocessing/batch.py index ee41e82..61ab5f9 100644 --- a/osl_ephys/preprocessing/batch.py +++ b/osl_ephys/preprocessing/batch.py @@ -459,7 +459,7 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False outnames = {"raw": outbase.format(run_id=run_id, ftype=ftype, fext="fif")} if Path(outnames["raw"]).exists() and not overwrite: raise ValueError( - "{} already exists. Please delete or do use overwrite=True.".format(fif_outname) + "{} already exists. Please delete or do use overwrite=True.".format(outnames['raw']) ) logger.info(f"Saving dataset['raw'] as {outnames['raw']}") dataset["raw"].save(outnames['raw'], overwrite=overwrite) @@ -494,6 +494,14 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False logger.info(f"Saving dataset['glm'] as {outnames['glm']}") dataset["glm"].save_pkl(outnames['glm'], overwrite=overwrite) + if "fig" in dataset and "fig" not in skip and dataset['fig'] is not None: + keys = dataset["fig"].keys() + outnames['fig'] = {} + for key in keys: + outnames['fig'][key] = outbase.format(run_id=run_id, ftype=key, fext="png") + logger.info(f"Saving dataset['fig'][{key}] as {outnames['fig'][key]}") + dataset["fig"].savefig(outnames['fig'][key], overwrite=overwrite) + # save remaining keys as pickle files for key in dataset: if key not in outnames and key not in skip: @@ -713,6 +721,7 @@ def run_proc_chain( overwrite=False, skip_save=None, extra_funcs=None, + covs=None, random_seed='auto', verbose="INFO", mneverbose="WARNING", @@ -745,6 +754,8 @@ def run_proc_chain( List of keys to skip writing to disk. If None, we don't skip any keys. extra_funcs : list User-defined functions. + covs : dict, pd.DataFrame, or None + Covariates for GLM. random_seed : 'auto' (default), int or None Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy. If None, no random seed is set. @@ -857,6 +868,8 @@ def run_proc_chain( "epochs": None, "event_id": config["meta"]["event_codes"], "ica": None, + "covs": covs, + "fig": {}, } # Do the preprocessing @@ -864,6 +877,7 @@ def run_proc_chain( method, userargs = next(iter(stage.items())) target = userargs.get("target", "raw") # Raw is default func = find_func(method, target=target, extra_funcs=extra_funcs) + # Actual function call dataset = func(dataset, userargs) @@ -959,6 +973,7 @@ def run_proc_batch( overwrite=False, skip_save=None, extra_funcs=None, + covs=None, random_seed='auto', verbose="INFO", mneverbose="WARNING", @@ -995,6 +1010,8 @@ def run_proc_batch( List of keys to skip writing to disk. If None, we don't skip any keys. extra_funcs : list User-defined functions. + covs : dict or pd.DataFrame + Covariates to use for building the GLM design random_seed : 'auto' (default), int or None Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy. If None, no random seed is set. @@ -1102,6 +1119,7 @@ def run_proc_batch( overwrite=overwrite, skip_save=skip_save, extra_funcs=extra_funcs, + covs=covs, random_seed=random_seed, ) @@ -1153,11 +1171,19 @@ def run_proc_batch( for key in group_inputs[0]: dataset[key] = [group_inputs[i][key] for i in range(len(group_inputs))] skip_save.append(key) + + if covs is not None: + dataset['covs'] = covs + dataset['fig'] = {} + for stage in deepcopy(config["group"]): method, userargs = next(iter(stage.items())) + # make sure the function always knows it's a group processing + userargs['run_on_group'] = True target = userargs.get("target", "raw") # Raw is default # skip.append(stage if userargs.get("skip_save") is True else None) # skip saving this stage to disk func = find_func(method, target=target, extra_funcs=extra_funcs) + # Actual function call dataset = func(dataset, userargs) outbase = os.path.join(outdir, "{ftype}.{fext}") diff --git a/osl_ephys/preprocessing/osl_wrappers.py b/osl_ephys/preprocessing/osl_wrappers.py index 175a600..f2f69a1 100644 --- a/osl_ephys/preprocessing/osl_wrappers.py +++ b/osl_ephys/preprocessing/osl_wrappers.py @@ -15,7 +15,9 @@ from os.path import exists from scipy import stats from pathlib import Path - +import glmtools +from ..glm import glm_epochs, glm_spectrum, glm_irasa, group_glm_epochs, group_glm_spectrum, MaxStatPermuteGLMSpectrum, ClusterPermuteGLMSpectrum +from ..glm.glm_base import SensorMaxStatPerm, SensorClusterPerm logger = logging.getLogger(__name__) @@ -922,3 +924,207 @@ def run_osl_ica_manualreject(dataset, userargs): else: logger.info("Components were not removed from raw data") return dataset + +#%% GLM wrappers + +def zscore_present_data(dataset, userargs): + """ + z-scoring parametric regressors, without NaNs + Nans will be zeros in the z-scored version + + Parameters + ---------- + dataset: dict + Dictionary containing at least an MNE object with the key ``covs``. + userargs: dict + Dictionary of additional arguments containing the keys ``keys``. + """ + keys = userargs.pop("keys", None) + # make sure keys is a single string or list of strings + if keys[0]=='[' and keys[-1]==']': + keys = keys[1:-1].split(' ') + + for key in keys: + new = stats.zscore(dataset["covs"][key], nan_policy='omit') + new[np.isnan(dataset["covs"][key])] = 0 + dataset["covs"][key] = new + return dataset + + +def glm_add_regressor(dataset, userargs): + """osl-ephys Batch wrapper for :py:func:`osl_ephys.preprocessing.osl_glm.add_regressor `. + + Parameters + """ + logger.info("osl-ephys Stage - {0}".format("GLM Add Regressor")) + if 'design_config' not in dataset: + dataset['design_config'] = glmtools.design.DesignConfig() + + rtype = userargs.pop("rtype", None) + name = userargs.pop("name", None) + codes = userargs.pop("codes", None) + preproc = userargs.pop("preproc", None) + key = userargs.pop("key", None) + + if rtype == 'Constant': + dataset['design_config'].add_regressor(name, rtype) + elif rtype == 'Categorical': + codes = [ + float(codes) + if np.logical_or(type(codes) == int, type(codes) == float) + else np.array(codes.split(" ")).astype(float) + ] + dataset['design_config'].add_regressor(name, rtype, codes=codes) + elif rtype == 'Parametric': + dataset['design_config'].add_regressor(name, rtype, datainfo=key, preproc=preproc) + elif rtype == 'MeanEffects': + dataset['design_config'].add_regressor(name=name + '_{0}',rtype=rtype, datainfo=key) + else: + raise ValueError("Unknown regressor type") + return dataset + + +def glm_add_contrast(dataset, userargs): + """osl-ephys Batch wrapper for :py:func:`osl_ephys.preprocessing.osl_glm.add_regressor `. + + Parameters + """ + logger.info("osl-ephys Stage - {0}".format("GLM Add Contrast")) + + simple = userargs.pop("simple", False) + name = userargs.pop("name", None) + values = userargs.pop("values", None) + + if simple: + dataset['design_config'].add_simple_contrasts() + else: + import re + def string_to_dict(input_string): + # Replace unquoted keys with quoted keys + input_string = re.sub(r'(? Date: Tue, 7 Jan 2025 15:37:54 +0000 Subject: [PATCH 02/10] add regressor option to glm wrapper --- osl_ephys/preprocessing/osl_wrappers.py | 27 ++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/osl_ephys/preprocessing/osl_wrappers.py b/osl_ephys/preprocessing/osl_wrappers.py index f2f69a1..5900b76 100644 --- a/osl_ephys/preprocessing/osl_wrappers.py +++ b/osl_ephys/preprocessing/osl_wrappers.py @@ -15,6 +15,7 @@ from os.path import exists from scipy import stats from pathlib import Path +import matplotlib.pyplot as plt import glmtools from ..glm import glm_epochs, glm_spectrum, glm_irasa, group_glm_epochs, group_glm_spectrum, MaxStatPermuteGLMSpectrum, ClusterPermuteGLMSpectrum from ..glm.glm_base import SensorMaxStatPerm, SensorClusterPerm @@ -969,12 +970,17 @@ def glm_add_regressor(dataset, userargs): if rtype == 'Constant': dataset['design_config'].add_regressor(name, rtype) elif rtype == 'Categorical': - codes = [ - float(codes) - if np.logical_or(type(codes) == int, type(codes) == float) - else np.array(codes.split(" ")).astype(float) - ] - dataset['design_config'].add_regressor(name, rtype, codes=codes) + if codes == 'unique': # add a regressor for each unique value + codes = np.unique(dataset['covs'][key]) + for code in codes: + dataset['design_config'].add_regressor(name + '_{0}'.format(code), rtype, codes=code) + else: + codes = [ + float(codes) + if np.logical_or(type(codes) == int, type(codes) == float) + else np.array(codes.split(" ")).astype(float) + ] + dataset['design_config'].add_regressor(name, rtype, codes=codes) elif rtype == 'Parametric': dataset['design_config'].add_regressor(name, rtype, datainfo=key, preproc=preproc) elif rtype == 'MeanEffects': @@ -1010,7 +1016,7 @@ def string_to_dict(input_string): return dataset -def glm(dataset, userargs): +def glm_fit(dataset, userargs): """ wrapper for the different glm functions in the glm module Parameters @@ -1110,6 +1116,9 @@ def glm_permutations(dataset, userargs): if type is None: raise ValueError("type not specified (e.g. 'max', 'cluster')") + thresh = userargs.pop("thresh", 95) + plot_sig = userargs.pop("plot_sig", True) + contrast = userargs.pop("contrast", None) contrast = dataset[target].contrast_names.index(contrast) fl_contrast = userargs.pop("fl_contrast", 0) @@ -1127,4 +1136,8 @@ def glm_permutations(dataset, userargs): elif method == 'spectrum' or method == 'glm_spectrum': dataset[name] = ClusterPermuteGLMSpectrum(dataset[target], contrast, fl_contrast, **userargs) + if plot_sig: + fig, ax = plt.subplots() + dataset[name].plot_sig_clusters(thresh, ax=ax) + dataset['fig'][name + 'sig' + thresh] = fig return dataset \ No newline at end of file From 8ba8cd63905c83c492ee3c40b402ffe3f05fe9ad Mon Sep 17 00:00:00 2001 From: Mats Date: Tue, 7 Jan 2025 15:43:13 +0000 Subject: [PATCH 03/10] add glm_fit defaults --- osl_ephys/preprocessing/osl_wrappers.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/osl_ephys/preprocessing/osl_wrappers.py b/osl_ephys/preprocessing/osl_wrappers.py index 5900b76..c4ea177 100644 --- a/osl_ephys/preprocessing/osl_wrappers.py +++ b/osl_ephys/preprocessing/osl_wrappers.py @@ -954,8 +954,13 @@ def zscore_present_data(dataset, userargs): def glm_add_regressor(dataset, userargs): """osl-ephys Batch wrapper for :py:func:`osl_ephys.preprocessing.osl_glm.add_regressor `. - + Parameters + ---------- + dataset: dict + Dictionary containing at least an MNE object with the key ``covs``. + userargs: dict + Dictionary of additional arguments containing the keys ``keys``. """ logger.info("osl-ephys Stage - {0}".format("GLM Add Regressor")) if 'design_config' not in dataset: @@ -1032,11 +1037,25 @@ def glm_fit(dataset, userargs): Input dictionary containing MNE objects that have been modified in place. """ run_on_group = userargs.pop("run_on_group", False) - target = userargs.pop("target", "raw") - name = userargs.pop("name", "glm") + method = userargs.pop("method", None) if method is None: raise ValueError("method not specified") + target = userargs.pop("target", None) + if target is None: + if run_on_group: + target = "glm" + else: + if method in ['epochs', 'glm_epochs']: + target = "epochs" + elif method in ['spectrum', 'glm_spectrum']: + target = "raw" + name = userargs.pop("name", None) + if name is None: + if run_on_group: + name = "group_glm" + else: + name = "glm" metric = userargs.pop("metric", 'copes') plot_summary = userargs.pop("plot_summary", True) From cc44eac7cc88a045c0ef4735490285ee7ceaf900 Mon Sep 17 00:00:00 2001 From: Mats Date: Tue, 7 Jan 2025 15:46:37 +0000 Subject: [PATCH 04/10] add glm perm defaults --- osl_ephys/preprocessing/osl_wrappers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/osl_ephys/preprocessing/osl_wrappers.py b/osl_ephys/preprocessing/osl_wrappers.py index c4ea177..d482b1a 100644 --- a/osl_ephys/preprocessing/osl_wrappers.py +++ b/osl_ephys/preprocessing/osl_wrappers.py @@ -1126,8 +1126,8 @@ def glm_permutations(dataset, userargs): dataset: dict Input dictionary containing MNE objects that have been modified in place. """ - target = userargs.pop("target", "glm") - name = userargs.pop("name", "glm_perm") + target = userargs.pop("target", "group_glm") + name = userargs.pop("name", "group_glm_perm") method = userargs.pop("method", None) if method is None: raise ValueError("method not specified") @@ -1135,7 +1135,7 @@ def glm_permutations(dataset, userargs): if type is None: raise ValueError("type not specified (e.g. 'max', 'cluster')") - thresh = userargs.pop("thresh", 95) + thresh = userargs.pop("threshold", 95) plot_sig = userargs.pop("plot_sig", True) contrast = userargs.pop("contrast", None) From fb22e87a7b250093e92f38f3b2d626579c9dc39d Mon Sep 17 00:00:00 2001 From: Mats Date: Tue, 7 Jan 2025 15:59:36 +0000 Subject: [PATCH 05/10] add glm contrast option --- osl_ephys/preprocessing/osl_wrappers.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/osl_ephys/preprocessing/osl_wrappers.py b/osl_ephys/preprocessing/osl_wrappers.py index d482b1a..9494780 100644 --- a/osl_ephys/preprocessing/osl_wrappers.py +++ b/osl_ephys/preprocessing/osl_wrappers.py @@ -1009,13 +1009,17 @@ def glm_add_contrast(dataset, userargs): if simple: dataset['design_config'].add_simple_contrasts() else: - import re - def string_to_dict(input_string): - # Replace unquoted keys with quoted keys - input_string = re.sub(r'(? Date: Tue, 7 Jan 2025 16:11:20 +0000 Subject: [PATCH 06/10] remove covs input from proc_chain --- osl_ephys/preprocessing/batch.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/osl_ephys/preprocessing/batch.py b/osl_ephys/preprocessing/batch.py index 61ab5f9..de108df 100644 --- a/osl_ephys/preprocessing/batch.py +++ b/osl_ephys/preprocessing/batch.py @@ -721,7 +721,6 @@ def run_proc_chain( overwrite=False, skip_save=None, extra_funcs=None, - covs=None, random_seed='auto', verbose="INFO", mneverbose="WARNING", @@ -868,7 +867,6 @@ def run_proc_chain( "epochs": None, "event_id": config["meta"]["event_codes"], "ica": None, - "covs": covs, "fig": {}, } From dcfd22e729b795558494ff278e6baa6241394e67 Mon Sep 17 00:00:00 2001 From: Mats Date: Tue, 7 Jan 2025 16:11:34 +0000 Subject: [PATCH 07/10] fix circular import --- osl_ephys/report/preproc_report.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/osl_ephys/report/preproc_report.py b/osl_ephys/report/preproc_report.py index e978efe..afa4ecd 100644 --- a/osl_ephys/report/preproc_report.py +++ b/osl_ephys/report/preproc_report.py @@ -35,12 +35,6 @@ from ..utils import process_file_inputs, validate_outdir from ..utils.logger import log_or_print -from ..preprocessing import ( - read_dataset, - load_config, - get_config_from_fif, - plot_preproc_flowchart, -) # ---------------------------------------------------------------------------------- @@ -64,6 +58,7 @@ def gen_report_from_fif(infiles, outdir, ftype=None, logsdir=None, run_id=None): run_id : str Run ID. """ + from ..preprocessing import read_dataset # Validate input files and directory to save html file and plots to infiles, outnames, good_files = process_file_inputs(infiles) @@ -493,7 +488,7 @@ def plot_flowchart(raw, savebase=None): Path to saved figure. """ - + from ..preprocessing import get_config_from_fif, plot_preproc_flowchart # Get config info from raw.info['description'] config_list = get_config_from_fif(raw) From 54afaef1ab1e79feaae61d2201b126a7ec3f5fac Mon Sep 17 00:00:00 2001 From: Mats Date: Tue, 7 Jan 2025 16:15:43 +0000 Subject: [PATCH 08/10] remove covs input from proc_chain --- osl_ephys/preprocessing/batch.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/osl_ephys/preprocessing/batch.py b/osl_ephys/preprocessing/batch.py index de108df..e4c7eef 100644 --- a/osl_ephys/preprocessing/batch.py +++ b/osl_ephys/preprocessing/batch.py @@ -753,8 +753,6 @@ def run_proc_chain( List of keys to skip writing to disk. If None, we don't skip any keys. extra_funcs : list User-defined functions. - covs : dict, pd.DataFrame, or None - Covariates for GLM. random_seed : 'auto' (default), int or None Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy. If None, no random seed is set. @@ -1117,7 +1115,6 @@ def run_proc_batch( overwrite=overwrite, skip_save=skip_save, extra_funcs=extra_funcs, - covs=covs, random_seed=random_seed, ) From 3b44f96367bc3d2b49ecc6abbd728c802c61a51b Mon Sep 17 00:00:00 2001 From: Mats Date: Tue, 7 Jan 2025 16:37:55 +0000 Subject: [PATCH 09/10] add run_osl to glm wrappers --- osl_ephys/preprocessing/osl_wrappers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/osl_ephys/preprocessing/osl_wrappers.py b/osl_ephys/preprocessing/osl_wrappers.py index 9494780..dc31397 100644 --- a/osl_ephys/preprocessing/osl_wrappers.py +++ b/osl_ephys/preprocessing/osl_wrappers.py @@ -928,7 +928,7 @@ def run_osl_ica_manualreject(dataset, userargs): #%% GLM wrappers -def zscore_present_data(dataset, userargs): +def run_osl_zscore_present_data(dataset, userargs): """ z-scoring parametric regressors, without NaNs Nans will be zeros in the z-scored version @@ -952,7 +952,7 @@ def zscore_present_data(dataset, userargs): return dataset -def glm_add_regressor(dataset, userargs): +def run_osl_glm_add_regressor(dataset, userargs): """osl-ephys Batch wrapper for :py:func:`osl_ephys.preprocessing.osl_glm.add_regressor `. Parameters @@ -995,7 +995,7 @@ def glm_add_regressor(dataset, userargs): return dataset -def glm_add_contrast(dataset, userargs): +def run_osl_glm_add_contrast(dataset, userargs): """osl-ephys Batch wrapper for :py:func:`osl_ephys.preprocessing.osl_glm.add_regressor `. Parameters @@ -1025,7 +1025,7 @@ def string_to_dict(input_string): return dataset -def glm_fit(dataset, userargs): +def run_osl_glm_fit(dataset, userargs): """ wrapper for the different glm functions in the glm module Parameters @@ -1115,7 +1115,7 @@ def glm_fit(dataset, userargs): return dataset -def glm_permutations(dataset, userargs): +def run_osl_glm_permutations(dataset, userargs): """ wrapper for the different permutation options in the glm module Parameters From 8739ab2a162a977436d18c4badc5e8f6157f55cf Mon Sep 17 00:00:00 2001 From: Mats Date: Tue, 7 Jan 2025 18:09:18 +0000 Subject: [PATCH 10/10] glm wrapper bug fixes --- osl_ephys/preprocessing/batch.py | 2 +- osl_ephys/preprocessing/osl_wrappers.py | 34 ++++++++++++------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/osl_ephys/preprocessing/batch.py b/osl_ephys/preprocessing/batch.py index e4c7eef..c180ee9 100644 --- a/osl_ephys/preprocessing/batch.py +++ b/osl_ephys/preprocessing/batch.py @@ -500,7 +500,7 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False for key in keys: outnames['fig'][key] = outbase.format(run_id=run_id, ftype=key, fext="png") logger.info(f"Saving dataset['fig'][{key}] as {outnames['fig'][key]}") - dataset["fig"].savefig(outnames['fig'][key], overwrite=overwrite) + dataset["fig"][key].savefig(outnames['fig'][key]) # save remaining keys as pickle files for key in dataset: diff --git a/osl_ephys/preprocessing/osl_wrappers.py b/osl_ephys/preprocessing/osl_wrappers.py index dc31397..2c1cc73 100644 --- a/osl_ephys/preprocessing/osl_wrappers.py +++ b/osl_ephys/preprocessing/osl_wrappers.py @@ -963,7 +963,7 @@ def run_osl_glm_add_regressor(dataset, userargs): Dictionary of additional arguments containing the keys ``keys``. """ logger.info("osl-ephys Stage - {0}".format("GLM Add Regressor")) - if 'design_config' not in dataset: + if 'design_config' not in dataset or not isinstance(dataset['design_config'], glmtools.design.DesignConfig): dataset['design_config'] = glmtools.design.DesignConfig() rtype = userargs.pop("rtype", None) @@ -978,16 +978,14 @@ def run_osl_glm_add_regressor(dataset, userargs): if codes == 'unique': # add a regressor for each unique value codes = np.unique(dataset['covs'][key]) for code in codes: - dataset['design_config'].add_regressor(name + '_{0}'.format(code), rtype, codes=code) + dataset['design_config'].add_regressor(name=name + '_{0}'.format(code), rtype=rtype, codes=code) else: - codes = [ - float(codes) + codes = [float(codes) if np.logical_or(type(codes) == int, type(codes) == float) - else np.array(codes.split(" ")).astype(float) - ] - dataset['design_config'].add_regressor(name, rtype, codes=codes) + else np.array(codes[0].split(" ")).astype(float)][0] + dataset['design_config'].add_regressor(name=name, rtype=rtype, codes=codes) elif rtype == 'Parametric': - dataset['design_config'].add_regressor(name, rtype, datainfo=key, preproc=preproc) + dataset['design_config'].add_regressor(name=name, rtype=rtype, datainfo=key, preproc=preproc) elif rtype == 'MeanEffects': dataset['design_config'].add_regressor(name=name + '_{0}',rtype=rtype, datainfo=key) else: @@ -1005,21 +1003,20 @@ def run_osl_glm_add_contrast(dataset, userargs): simple = userargs.pop("simple", False) name = userargs.pop("name", None) values = userargs.pop("values", None) + key = userargs.pop("key", None) if simple: dataset['design_config'].add_simple_contrasts() else: if values == 'unique': - values = np.unique(dataset['covs'][name]) - values={f"{name}_{v}": 1/len(values) for v in values} + values = np.unique(dataset['covs'][key]) + values={f"{key}_{v}": 1/len(values) for v in values} else: - import re - def string_to_dict(input_string): - # Replace unquoted keys with quoted keys - input_string = re.sub(r'(?