From c840bc6a74e15f2d0a626763fe22ab0d5ebc59b6 Mon Sep 17 00:00:00 2001 From: matsvanes Date: Mon, 9 Sep 2024 14:55:40 +0100 Subject: [PATCH 01/13] update help info --- osl/preprocessing/batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/osl/preprocessing/batch.py b/osl/preprocessing/batch.py index e5c7477..b5f3efb 100644 --- a/osl/preprocessing/batch.py +++ b/osl/preprocessing/batch.py @@ -912,11 +912,11 @@ def run_proc_batch( Should we generate a report? overwrite : bool Should we overwrite the output file if it exists? + extra_funcs : list + User-defined functions. random_seed : 'auto' (default), int or None Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy. If None, no random seed is set. - extra_funcs : list - User-defined functions. verbose : str Level of info to print. Can be: ``'CRITICAL'``, ``'ERROR'``, ``'WARNING'``, ``'INFO'``, ``'DEBUG'`` or ``'NOTSET'``. From ed7141384ddf544fc83746e81a4a579da76356bd Mon Sep 17 00:00:00 2001 From: matsvanes Date: Mon, 9 Sep 2024 14:55:57 +0100 Subject: [PATCH 02/13] add read_dataset func --- osl/preprocessing/osl_wrappers.py | 79 +++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/osl/preprocessing/osl_wrappers.py b/osl/preprocessing/osl_wrappers.py index b31a909..8710fe3 100644 --- a/osl/preprocessing/osl_wrappers.py +++ b/osl/preprocessing/osl_wrappers.py @@ -10,8 +10,10 @@ import mne import numpy as np import sails +import yaml from os.path import exists from scipy import stats +from pathlib import Path logger = logging.getLogger(__name__) @@ -620,7 +622,84 @@ def drop_bad_epochs( # Wrapper functions +def run_osl_read_dataset(dataset, userargs): + """Reads ``fif``/``npy``/``yml`` files associated with a dataset. + Parameters + ---------- + fif : str + Path to raw fif file (can be preprocessed). + preload : bool + Should we load the raw fif data? + ftype : str + Extension for the fif file (will be replaced for e.g. ``'_events.npy'`` or + ``'_ica.fif'``). If ``None``, we assume the fif file is preprocessed with + OSL and has the extension ``'_preproc-raw'``. If this fails, we guess + the extension as whatever comes after the last ``'_'``. + + Returns + ------- + dataset : dict + Contains keys: ``'raw'``, ``'events'``, ``'event_id'``, ``'epochs'``, ``'ica'``. 
+ """ + + logger.info("OSL Stage - {0}".format( "read_dataset")) + logger.info("userargs: {0}".format(str(userargs))) + ftype = userargs.pop("ftype", None) + + fif = dataset['raw'].filenames[0] + + # Guess extension + if ftype is None: + logger.info("Guessing the preproc extension") + if "preproc-raw" in fif: + logger.info('Assuming fif file type is "preproc-raw"') + ftype = "preproc-raw" + else: + if len(fif.split("_"))<2: + logger.error("Unable to guess the fif file extension") + else: + logger.info('Assuming fif file type is the last "_" separated string') + ftype = fif.split("_")[-1].split('.')[-2] + + # add extension to fif file name + ftype = ftype + ".fif" + + events = Path(fif.replace(ftype, "events.npy")) + if events.exists(): + print("Reading", events) + events = np.load(events) + else: + events = None + + event_id = Path(fif.replace(ftype, "event-id.yml")) + if event_id.exists(): + print("Reading", event_id) + with open(event_id, "r") as file: + event_id = yaml.load(file, Loader=yaml.Loader) + else: + event_id = None + + epochs = Path(fif.replace(ftype, "epo.fif")) + if epochs.exists(): + print("Reading", epochs) + epochs = mne.read_epochs(epochs) + else: + epochs = None + + ica = Path(fif.replace(ftype, "ica.fif")) + if ica.exists(): + print("Reading", ica) + ica = mne.preprocessing.read_ica(ica) + else: + ica = None + + dataset['event_id'] = event_id + dataset['events'] = events + dataset['ica'] = ica + dataset['epochs'] = epochs + + return dataset def run_osl_bad_segments(dataset, userargs): """OSL-Batch wrapper for :py:meth:`detect_badsegments `. From 350fcb4584d3624b76545cc39011cdff5655e782 Mon Sep 17 00:00:00 2001 From: matsvanes Date: Tue, 10 Sep 2024 10:05:18 +0100 Subject: [PATCH 03/13] save dataset['glm'] as osl glm model in write_dataset --- osl/preprocessing/batch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/osl/preprocessing/batch.py b/osl/preprocessing/batch.py index b5f3efb..48ebeb6 100644 --- a/osl/preprocessing/batch.py +++ b/osl/preprocessing/batch.py @@ -442,6 +442,9 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False outname = outbase.format(run_id=run_id, ftype="tfr", fext="fif") dataset["tfr"].save(outname, overwrite=overwrite) + if "glm" in dataset and dataset['glm'] is not None: + outname = outbase.format(run_id=run_id, ftype="glm", fext="pkl") + dataset["glm"].save_pkl(outname, overwrite=overwrite) return fif_outname def read_dataset(fif, preload=False, ftype=None): From 38a5be711d57b67f1f1f9caa6ba87aa57c174be3 Mon Sep 17 00:00:00 2001 From: matsvanes Date: Tue, 10 Sep 2024 10:20:46 +0100 Subject: [PATCH 04/13] add Study.refresh() function --- osl/utils/study.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/osl/utils/study.py b/osl/utils/study.py index 48ca805..49976a3 100644 --- a/osl/utils/study.py +++ b/osl/utils/study.py @@ -100,6 +100,12 @@ def __init__(self, studydir): for d in self.match_values[1:]: self.fields[key].append(d[key]) + + def refresh(self): + """Refresh the study directory.""" + return self.__init__(self.studydir) + + def get(self, check_exist=True, **kwargs): """Get files from the study directory that match the fieldnames. 
From b1b2ecaa2f30b879f8176f945d4f7a7e3537707d Mon Sep 17 00:00:00 2001 From: matsvanes Date: Tue, 10 Sep 2024 10:51:38 +0100 Subject: [PATCH 05/13] fix missing import --- osl/glm/glm_epochs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/osl/glm/glm_epochs.py b/osl/glm/glm_epochs.py index c44038b..49c380e 100644 --- a/osl/glm/glm_epochs.py +++ b/osl/glm/glm_epochs.py @@ -2,6 +2,7 @@ import os import pickle from copy import deepcopy +from pathlib import Path import glmtools as glm import mne From 3c736272723b42d22efeb3e1c10e3bd4d2b972ea Mon Sep 17 00:00:00 2001 From: matsvanes Date: Tue, 10 Sep 2024 11:32:39 +0100 Subject: [PATCH 06/13] fix save_pkl --- osl/glm/glm_epochs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/osl/glm/glm_epochs.py b/osl/glm/glm_epochs.py index 49c380e..d95b4e6 100644 --- a/osl/glm/glm_epochs.py +++ b/osl/glm/glm_epochs.py @@ -52,7 +52,8 @@ def save_pkl(self, outname, overwrite=True, save_data=False): msg = "{} already exists. Please delete or do use overwrite=True." raise ValueError(msg.format(outname)) - self.config.detrend_func = None # Have to drop this to pickle + if hasattr(self, 'config'): + self.config.detrend_func = None # Have to drop this to pickle # This is hacky - but pickles are all or nothing and I don't know how # else to do it. HDF5 would be better longer term From e8588f69a5e9399f37d8bf68f18342e372c33d6e Mon Sep 17 00:00:00 2001 From: matsvanes Date: Tue, 10 Sep 2024 14:56:36 +0100 Subject: [PATCH 07/13] add save_pkl to group result --- osl/glm/glm_epochs.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/osl/glm/glm_epochs.py b/osl/glm/glm_epochs.py index d95b4e6..c0e431c 100644 --- a/osl/glm/glm_epochs.py +++ b/osl/glm/glm_epochs.py @@ -216,7 +216,40 @@ def get_fl_contrast(self, fl_con): ret_con.data = ret_con.data[:, fl_con, :, :] return ret_con + + def save_pkl(self, outname, overwrite=True, save_data=False): + """Save GLM-Epochs result to a pickle file. + + Parameters + ---------- + outname : str + Filename or full file path to write pickle to + overwrite : bool + Overwrite previous file if one exists? (Default value = True) + save_data : bool + Save epochs data in pickle? This is omitted by default to save disk + space (Default value = False) + """ + if Path(outname).exists() and not overwrite: + msg = "{} already exists. Please delete or do use overwrite=True." + raise ValueError(msg.format(outname)) + if hasattr(self, 'config'): + self.config.detrend_func = None # Have to drop this to pickle + + # This is hacky - but pickles are all or nothing and I don't know how + # else to do it. 
HDF5 would be better longer term + if save_data == False: + # Temporarily remove data before saving + dd = self.data + self.data = None + + with open(outname, 'bw') as outp: + pickle.dump(self, outp) + + # Put data back + if save_data == False: + self.data = dd #%% ------------------------------------------------------ From 69573745495794ab6b04bf7ac7244320f2a7506a Mon Sep 17 00:00:00 2001 From: Mats Date: Thu, 19 Sep 2024 12:12:04 +0100 Subject: [PATCH 08/13] implement group processing in preproc batch --- osl/preprocessing/batch.py | 309 +++++++++++++++++++++++++------------ 1 file changed, 211 insertions(+), 98 deletions(-) diff --git a/osl/preprocessing/batch.py b/osl/preprocessing/batch.py index 48ebeb6..5f2ad37 100644 --- a/osl/preprocessing/batch.py +++ b/osl/preprocessing/batch.py @@ -17,6 +17,7 @@ import traceback import re import logging +import pickle from pathlib import Path from copy import deepcopy from functools import partial, wraps @@ -254,33 +255,58 @@ def load_config(config): elif "versions" not in config['meta']: config["meta"]["versions"] = None - if "preproc" not in config: - raise KeyError("Please specify preprocessing steps in config.") + if "preproc" not in config and "group" not in config: + raise KeyError("Please specify preprocessing and/or group processing steps in config.") + + if "preproc" in config: + for stage in config["preproc"]: + # Check each stage is a dictionary with a single key + if not isinstance(stage, dict): + raise ValueError( + "Preprocessing stage '{0}' is a {1} not a dict".format( + stage, type(stage) + ) + ) - for stage in config["preproc"]: - # Check each stage is a dictionary with a single key - if not isinstance(stage, dict): - raise ValueError( - "Preprocessing stage '{0}' is a {1} not a dict".format( - stage, type(stage) + if len(stage) != 1: + raise ValueError( + "Preprocessing stage '{0}' should only have a single key".format(stage) ) - ) - if len(stage) != 1: - raise ValueError( - "Preprocessing stage '{0}' should only have a single key".format(stage) - ) + for key, val in stage.items(): + # internally we want options to be an empty dict (for now at least) + if val in ["null", "None", None]: + stage[key] = {} + + for step in config["preproc"]: + if config["meta"]["event_codes"] is None and "find_events" in step.values(): + raise KeyError( + "event_codes must be passed in config if we are finding events." + ) + else: + config['preproc'] = None + + if "group" in config: + for stage in config["group"]: + # Check each stage is a dictionary with a single key + if not isinstance(stage, dict): + raise ValueError( + "Group processing stage '{0}' is a {1} not a dict".format( + stage, type(stage) + ) + ) - for key, val in stage.items(): - # internally we want options to be an empty dict (for now at least) - if val in ["null", "None", None]: - stage[key] = {} + if len(stage) != 1: + raise ValueError( + "Group processing stage '{0}' should only have a single key".format(stage) + ) - for step in config["preproc"]: - if config["meta"]["event_codes"] is None and "find_events" in step.values(): - raise KeyError( - "event_codes must be passed in config if we are finding events." 
- ) + for key, val in stage.items(): + # internally we want options to be an empty dict (for now at least) + if val in ["null", "None", None]: + stage[key] = {} + else: + config['group'] = None return config @@ -386,7 +412,7 @@ def append_preproc_info(dataset, config, extra_funcs=None): return dataset -def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False): +def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False, skip=None): """Write preprocessed data to a file. Will write all keys in the dataset dict to disk with corresponding extensions. @@ -403,49 +429,76 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False Extension for the fif file (default ``preproc-raw``) overwrite : bool Should we overwrite if the file already exists? + skip : list or None + List of keys to skip writing to disk. If None, we don't skip any keys. Output ------ fif_outname : str The saved fif file name """ + + if skip is None: + skip = [] + else: + [logger.info("Skip saving of dataset['{}']".format(key)) for key in skip] # Strip "_preproc-raw" or "_raw" from the run id for string in ["_preproc-raw", "_raw"]: if string in run_id: run_id = run_id.replace(string, "") + + if "raw" in skip: + outnames = {} + else: + outnames = {"raw": outbase.format(run_id=run_id, ftype=ftype, fext="fif")} + if Path(outnames["raw"]).exists() and not overwrite: + raise ValueError( + "{} already exists. Please delete or do use overwrite=True.".format(fif_outname) + ) + logger.info(f"Saving dataset['raw'] as {outnames['raw']}") + dataset["raw"].save(outnames['raw'], overwrite=overwrite) + + if "events" in dataset and "events" not in skip and dataset['events'] is not None: + outnames['events'] = outbase.format(run_id=run_id, ftype="events", fext="npy") + logger.info(f"Saving dataset['events'] as {outnames['events']}") + np.save(outnames['events'], dataset["events"]) + + if "event_id" in dataset and "event_id" not in skip and dataset['event_id'] is not None: + outnames['event_id'] = outbase.format(run_id=run_id, ftype="event-id", fext="yml") + logger.info(f"Saving dataset['event_id'] as {outnames['event_id']}") + yaml.dump(dataset["event_id"], open(outnames['event_id'], "w")) + + if "epochs" in dataset and "epochs" not in skip and dataset['epochs'] is not None: + outnames['epochs'] = outbase.format(run_id=run_id, ftype="epo", fext="fif") + logger.info(f"Saving dataset['epochs'] as {outnames['epochs']}") + dataset["epochs"].save(outnames['epochs'], overwrite=overwrite) + + if "ica" in dataset and "ica" not in skip and dataset['ica'] is not None: + outnames['ica'] = outbase.format(run_id=run_id, ftype="ica", fext="fif") + logger.info(f"Saving dataset['ica'] as {outnames['ica']}") + dataset["ica"].save(outnames['ica'], overwrite=overwrite) + + if "tfr" in dataset and "tfr" not in skip and dataset['tfr'] is not None: + outnames['tfr'] = outbase.format(run_id=run_id, ftype="tfr", fext="fif") + logger.info(f"Saving dataset['tfr'] as {outnames['tfr']}") + dataset["tfr"].save(outnames['tfr'], overwrite=overwrite) + + if "glm" in dataset and "glm" not in skip and dataset['glm'] is not None: + outnames['glm'] = outbase.format(run_id=run_id, ftype="glm", fext="pkl") + logger.info(f"Saving dataset['glm'] as {outnames['glm']}") + dataset["glm"].save_pkl(outnames['glm'], overwrite=overwrite) + + # save remaining keys as pickle files + for key in dataset: + if key not in outnames and key not in skip: + outnames[key] = outbase.format(run_id=run_id, ftype=key, fext="pkl") + 
logger.info(f"Saving dataset['{key}'] as {outnames[key]}") + if (not os.path.exists(outnames[key]) or overwrite) and key not in skip and dataset[key] is not None: + with open(outnames[key], "wb") as f: + pickle.dump(dataset[key], f) + return outnames - fif_outname = outbase.format(run_id=run_id, ftype=ftype, fext="fif") - if Path(fif_outname).exists() and not overwrite: - raise ValueError( - "{} already exists. Please delete or do use overwrite=True.".format(fif_outname) - ) - dataset["raw"].save(fif_outname, overwrite=overwrite) - - if "events" in dataset and dataset['events'] is not None: - outname = outbase.format(run_id=run_id, ftype="events", fext="npy") - np.save(outname, dataset["events"]) - - if "event_id" in dataset and dataset['event_id'] is not None: - outname = outbase.format(run_id=run_id, ftype="event-id", fext="yml") - yaml.dump(dataset["event_id"], open(outname, "w")) - - if "epochs" in dataset and dataset['epochs'] is not None: - outname = outbase.format(run_id=run_id, ftype="epo", fext="fif") - dataset["epochs"].save(outname, overwrite=overwrite) - - if "ica" in dataset and dataset['ica'] is not None: - outname = outbase.format(run_id=run_id, ftype="ica", fext="fif") - dataset["ica"].save(outname, overwrite=overwrite) - - if "tfr" in dataset and dataset['tfr'] is not None: - outname = outbase.format(run_id=run_id, ftype="tfr", fext="fif") - dataset["tfr"].save(outname, overwrite=overwrite) - - if "glm" in dataset and dataset['glm'] is not None: - outname = outbase.format(run_id=run_id, ftype="glm", fext="pkl") - dataset["glm"].save_pkl(outname, overwrite=overwrite) - return fif_outname def read_dataset(fif, preload=False, ftype=None): """Reads ``fif``/``npy``/``yml`` files associated with a dataset. @@ -573,11 +626,16 @@ def plot_preproc_flowchart( ax.set_xticks([]) ax.set_yticks([]) if title == None: - ax.set_title("OSL Preprocessing Recipe", fontsize=24) + ax.set_title("OSL Processing Recipe", fontsize=24) else: ax.set_title(title, fontsize=24) - - stage_height = 1 / (1 + len(config["preproc"])) + + tmp_h = 1 + if config["preproc"] is not None: + tmp_h += 1 + len(config["preproc"]) + if config["group"] is not None: + tmp_h += 1 + len(config["group"]) + stage_height = 1 / tmp_h box = dict(boxstyle="round", facecolor=stagecol, alpha=1, pad=0.3) startbox = dict(boxstyle="round", facecolor=startcol, alpha=1) @@ -587,12 +645,16 @@ def plot_preproc_flowchart( "weight": "normal", "size": 16, } - - stages = [{"input": ""}, *config["preproc"], {"output": ""}] + stages = [{"input": ""}] + if config['preproc'] is not None: + stages += [{"preproc": ""}, *config["preproc"]] + if config['group'] is not None: + stages += [{"group": ""}, *config["group"]] + stages.append({"output": ""}) stage_str = "$\\bf{{{0}}}$ {1}" ax.arrow( - 0.5, 1, 0.0, -1, fc="k", ec="k", head_width=0.045, + 0.5, 1, 0.0, -1+0.02, fc="k", ec="k", head_width=0.045, head_length=0.035, length_includes_head=True, ) @@ -600,7 +662,7 @@ def plot_preproc_flowchart( method, userargs = next(iter(stage.items())) method = method.replace("_", "\_") - if method in ["input", "output"]: + if method in ["input", "preproc", "group", "output"]: b = startbox else: b = box @@ -644,6 +706,7 @@ def run_proc_chain( ret_dataset=True, gen_report=None, overwrite=False, + skip_save=None, extra_funcs=None, random_seed='auto', verbose="INFO", @@ -673,6 +736,8 @@ def run_proc_chain( Should we generate a report? overwrite : bool Should we overwrite the output file if it already exists? 
+ skip_save: list or None (default) + List of keys to skip writing to disk. If None, we don't skip any keys. extra_funcs : list User-defined functions. random_seed : 'auto' (default), int or None @@ -799,9 +864,9 @@ def run_proc_chain( # Add preprocessing info to dataset dict dataset = append_preproc_info(dataset, config, extra_funcs) - fif_outname = None + outnames = {"raw": None} if outdir is not None: - fif_outname = write_dataset(dataset, outbase, run_id, overwrite=overwrite) + outnames = write_dataset(dataset, outbase, run_id, overwrite=overwrite, skip=skip_save) # Generate report data if gen_report: @@ -811,12 +876,12 @@ def run_proc_chain( from ..report import gen_html_data, gen_html_page # avoids circular import logger.info("{0} : Generating Report".format(now)) - report_data_dir = validate_outdir(reportdir / Path(fif_outname).stem) + report_data_dir = validate_outdir(reportdir / Path(outnames["raw"]).stem) gen_html_data( dataset["raw"], report_data_dir, ica=dataset["ica"], - preproc_fif_filename=fif_outname, + preproc_fif_filename=outnames["raw"], logsdir=logsdir, run_id=run_id, ) @@ -843,9 +908,9 @@ def run_proc_chain( logger.error(traceback.print_tb(ex_traceback)) with open(logfile.replace(".log", ".error.log"), "w") as f: - f.write("OSL PREPROCESSING CHAIN failed at: {0}".format(now)) + f.write("OSL PREPROCESSING CHAIN FAILED AT: {0}".format(now)) f.write("\n") - f.write('Processing filed during stage : "{0}"'.format(method)) + f.write('Processing failed during stage : "{0}"'.format(method)) f.write(str(ex_type)) f.write("\n") f.write(str(ex_value)) @@ -858,17 +923,21 @@ def run_proc_chain( # variable type return {} else: + if 'group' in config: + return False, None return False now = strftime("%Y-%m-%d %H:%M:%S", localtime()) logger.info("{0} : Processing Complete".format(now)) - if fif_outname is not None: - logger.info("Output file is {}".format(fif_outname)) + if outnames["raw"] is not None: + logger.info("Output file is {}".format(outnames["raw"])) if ret_dataset: return dataset else: + if 'group' in config: + return True, outnames return True @@ -882,6 +951,7 @@ def run_proc_batch( reportdir=None, gen_report=True, overwrite=False, + skip_save=None, extra_funcs=None, random_seed='auto', verbose="INFO", @@ -915,6 +985,8 @@ def run_proc_batch( Should we generate a report? overwrite : bool Should we overwrite the output file if it exists? + skip_save: list or None (default) + List of keys to skip writing to disk. If None, we don't skip any keys. extra_funcs : list User-defined functions. 
random_seed : 'auto' (default), int or None @@ -1011,42 +1083,83 @@ def run_proc_batch( if strictrun: logger.info('User confirms input config') - # Create partial function with fixed options - pool_func = partial( - run_proc_chain, - outdir=outdir, - ftype=ftype, - logsdir=logsdir, - reportdir=reportdir, - ret_dataset=False, - gen_report=gen_report, - overwrite=overwrite, - extra_funcs=extra_funcs, - random_seed=random_seed, - ) + if config['preproc'] is not None: + # Create partial function with fixed options + pool_func = partial( + run_proc_chain, + outdir=outdir, + ftype=ftype, + logsdir=logsdir, + reportdir=reportdir, + ret_dataset=False, + gen_report=gen_report, + overwrite=overwrite, + skip_save=skip_save, + extra_funcs=extra_funcs, + random_seed=random_seed, + ) - # Loop through input files to generate arguments for run_proc_chain - args = [] - for infile, subject in zip(infiles, subjects): - args.append((config, infile, subject)) + # Loop through input files to generate arguments for run_proc_chain + args = [] + for infile, subject in zip(infiles, subjects): + args.append((config, infile, subject)) - # Actually run the processes - if dask_client: - proc_flags = dask_parallel_bag(pool_func, args) + # Actually run the processes + if dask_client: + proc_flags = dask_parallel_bag(pool_func, args) + else: + proc_flags = [pool_func(*aa) for aa in args] + + if isinstance(proc_flags[0], tuple): + group_inputs = [flag[1] for flag in proc_flags] + proc_flags = [flag[0] for flag in proc_flags] + + osl_logger.set_up(log_file=logfile, level=verbose, startup=False) + logger.info( + "Processed {0}/{1} files successfully".format( + np.sum(proc_flags), len(proc_flags) + ) + ) + + # Generate a report + if gen_report and len(infiles) > 0: + from ..report import preproc_report # avoids circular import + preproc_report.gen_html_page(reportdir) + else: - proc_flags = [pool_func(*aa) for aa in args] + group_inputs = [{"raw": infile} for infile in infiles] + proc_flags = [None for sub in infiles] + + osl_logger.set_up(log_file=logfile, level=verbose, startup=False) + logger.info("No preprocessing steps specified. 
Skipping preprocessing.") - osl_logger.set_up(log_file=logfile, level=verbose, startup=False) - logger.info( - "Processed {0}/{1} files successfully".format( - np.sum(proc_flags), len(proc_flags) + + # start group processing + if config['group'] is not None: + logger.info("Starting Group Processing") + logger.info( + "Valid input files {0}/{1}".format( + np.sum(proc_flags), len(proc_flags) + ) ) - ) + dataset = {} + skip_save=[] + for key in group_inputs[0]: + dataset[key] = [group_inputs[i][key] for i in range(len(group_inputs))] + skip_save.append(key) + for stage in deepcopy(config["group"]): + method, userargs = next(iter(stage.items())) + target = userargs.get("target", "raw") # Raw is default + # skip.append(stage if userargs.get("skip_save") is True else None) # skip saving this stage to disk + func = find_func(method, target=target, extra_funcs=extra_funcs) + # Actual function call + dataset = func(dataset, userargs) + outbase = os.path.join(outdir, "{ftype}.{fext}") + outnames = write_dataset(dataset, outbase, '', ftype='', overwrite=overwrite, skip=skip_save) - # Generate a report - if gen_report and len(infiles) > 0: + # rerun the summary report + if gen_report: from ..report import preproc_report # avoids circular import - preproc_report.gen_html_page(reportdir) if preproc_report.gen_html_summary(reportdir, logsdir): logger.info("******************************" + "*" * len(str(reportdir))) logger.info(f"* REMEMBER TO CHECK REPORT: {reportdir} *") From 973d9ef1d03b51439dbbe3d79cc8061237bffc1f Mon Sep 17 00:00:00 2001 From: Mats Date: Fri, 20 Sep 2024 10:20:31 +0100 Subject: [PATCH 09/13] fix bug in load_config --- osl/preprocessing/batch.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/osl/preprocessing/batch.py b/osl/preprocessing/batch.py index 5f2ad37..e875317 100644 --- a/osl/preprocessing/batch.py +++ b/osl/preprocessing/batch.py @@ -247,6 +247,11 @@ def load_config(config): # We have a string config = yaml.load(config, Loader=yaml.FullLoader) + # do some checks on the config + for key in config: + if config[key] == 'None': + config[key] = None + # Initialise missing values in config if "meta" not in config: config["meta"] = {"event_codes": None} @@ -258,7 +263,7 @@ def load_config(config): if "preproc" not in config and "group" not in config: raise KeyError("Please specify preprocessing and/or group processing steps in config.") - if "preproc" in config: + if "preproc" in config and config["preproc"] is not None: for stage in config["preproc"]: # Check each stage is a dictionary with a single key if not isinstance(stage, dict): @@ -286,7 +291,7 @@ def load_config(config): else: config['preproc'] = None - if "group" in config: + if "group" in config and config["group"] is not None: for stage in config["group"]: # Check each stage is a dictionary with a single key if not isinstance(stage, dict): From 38d75911b514942513d1afc1cbbae0144ba657fd Mon Sep 17 00:00:00 2001 From: Mats Date: Fri, 20 Sep 2024 12:39:07 +0100 Subject: [PATCH 10/13] enable loading extra keys in read_dataset wrapper --- osl/preprocessing/osl_wrappers.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/osl/preprocessing/osl_wrappers.py b/osl/preprocessing/osl_wrappers.py index 8710fe3..b6310e8 100644 --- a/osl/preprocessing/osl_wrappers.py +++ b/osl/preprocessing/osl_wrappers.py @@ -11,6 +11,7 @@ import numpy as np import sails import yaml +import pickle from os.path import exists from scipy import stats from pathlib import Path @@ -636,6 +637,10 @@ def 
run_osl_read_dataset(dataset, userargs): ``'_ica.fif'``). If ``None``, we assume the fif file is preprocessed with OSL and has the extension ``'_preproc-raw'``. If this fails, we guess the extension as whatever comes after the last ``'_'``. + extra_keys : str + Space separated list of extra keys to read in from the same directory as the fif file. + If no suffix is provided, it's assumed to be .pkl. e.g., 'glm' will read in '..._glm.pkl' + 'events.npy' will read in '..._events.npy'. Returns ------- @@ -646,6 +651,7 @@ def run_osl_read_dataset(dataset, userargs): logger.info("OSL Stage - {0}".format( "read_dataset")) logger.info("userargs: {0}".format(str(userargs))) ftype = userargs.pop("ftype", None) + extra_keys = userargs.pop("extra_keys", []).split(" ") fif = dataset['raw'].filenames[0] @@ -699,6 +705,22 @@ def run_osl_read_dataset(dataset, userargs): dataset['ica'] = ica dataset['epochs'] = epochs + if len(extra_keys)>0: + for key in extra_keys: + extra_file = Path(fif.replace(ftype, key)) + key = key.split(".")[0] + if '.' not in extra_file.name: + extra_file = extra_file.with_suffix('.pkl') + if extra_file.exists(): + print("Reading", extra_file) + if '.pkl' in extra_file.name: + with open(extra_file, 'rb') as outp: + dataset[key] = pickle.load(outp) + elif '.npy' in extra_file.name: + dataset[key] = np.load(extra_file) + elif '.yml' in extra_file.name: + with open(extra_file, 'r') as file: + dataset[key] = yaml.load(file, Loader=yaml.Loader) return dataset def run_osl_bad_segments(dataset, userargs): From 3fb27b904624e179bcde94a4c9ef7a4c876bd6b7 Mon Sep 17 00:00:00 2001 From: Mats Date: Fri, 20 Sep 2024 15:17:57 +0100 Subject: [PATCH 11/13] fix read_dataset wrapper --- osl/preprocessing/osl_wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/osl/preprocessing/osl_wrappers.py b/osl/preprocessing/osl_wrappers.py index b6310e8..bf0eadc 100644 --- a/osl/preprocessing/osl_wrappers.py +++ b/osl/preprocessing/osl_wrappers.py @@ -651,7 +651,7 @@ def run_osl_read_dataset(dataset, userargs): logger.info("OSL Stage - {0}".format( "read_dataset")) logger.info("userargs: {0}".format(str(userargs))) ftype = userargs.pop("ftype", None) - extra_keys = userargs.pop("extra_keys", []).split(" ") + extra_keys = userargs.pop("extra_keys", []) fif = dataset['raw'].filenames[0] @@ -706,6 +706,7 @@ def run_osl_read_dataset(dataset, userargs): dataset['epochs'] = epochs if len(extra_keys)>0: + extra_keys = extra_keys.split(" ") for key in extra_keys: extra_file = Path(fif.replace(ftype, key)) key = key.split(".")[0] From 2ebd7032c1c8adc2fe1f909905ad4f2a7c1d4631 Mon Sep 17 00:00:00 2001 From: Mats Date: Fri, 20 Sep 2024 16:13:59 +0100 Subject: [PATCH 12/13] fix read_dataset --- osl/preprocessing/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/osl/preprocessing/batch.py b/osl/preprocessing/batch.py index e875317..d87fa4e 100644 --- a/osl/preprocessing/batch.py +++ b/osl/preprocessing/batch.py @@ -454,7 +454,7 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False run_id = run_id.replace(string, "") if "raw" in skip: - outnames = {} + outnames = {"raw": None} else: outnames = {"raw": outbase.format(run_id=run_id, ftype=ftype, fext="fif")} if Path(outnames["raw"]).exists() and not overwrite: From 93ea547e0158cebfeec5e0946b5c15b2d9046ddb Mon Sep 17 00:00:00 2001 From: Mats Date: Fri, 20 Sep 2024 18:14:10 +0100 Subject: [PATCH 13/13] fix parcel joint plotting in glm epochs --- osl/glm/glm_epochs.py | 27 ++++++++++++++++++++++++--- 
1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/osl/glm/glm_epochs.py b/osl/glm/glm_epochs.py index c0e431c..ecd0789 100644 --- a/osl/glm/glm_epochs.py +++ b/osl/glm/glm_epochs.py @@ -108,8 +108,18 @@ def plot_joint_contrast(self, contrast=0, metric='copes', title=None): if title is None: title = 'C {} : {}'.format(contrast, self.design.contrast_names[contrast]) - - evo.plot_joint(title=title) + + try: + evo.plot_joint(title=title) + except: + from .glm_spectrum import plot_joint_spectrum + import matplotlib.pyplot as plt + fig = plt.figure() + fig.subplots_adjust(top=0.8) + ax = plt.subplot(111) + plot_joint_spectrum(evo.times, evo.get_data().T, evo.info, title=title, ax=ax) + ax.child_axes[0].set_xlabel('Time (s)') + ax.child_axes[0].set_ylabel(metric) class GroupGLMEpochs(GroupGLMBaseResult): @@ -188,7 +198,18 @@ def plot_joint_contrast(self, gcontrast=0, fcontrast=0, metric='copes', title=No joint_args['ts_args'] = {'scalings': dict(eeg=1, grad=1, mag=1), 'units': dict(eeg='tstats', grad='tstats', mag='tstats')} - evo.plot_joint(title=title, **joint_args) + try: + evo.plot_joint(title=title, **joint_args) + except: + from .glm_spectrum import plot_joint_spectrum + import matplotlib.pyplot as plt + fig = plt.figure() + fig.subplots_adjust(top=0.8) + ax = plt.subplot(111) + plot_joint_spectrum(evo.times, evo.get_data().T, evo.info, title=title, **joint_args, ax=ax) + ax.child_axes[0].set_xlabel('Time (s)') + ax.child_axes[0].set_ylabel(metric) + def get_channel_adjacency(self): """Return adjacency matrix of channels."""
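
Taken together, these patches let a single config drive both subject-level preprocessing and group-level processing in run_proc_batch: the "preproc" section runs once per input file, the "group" section runs once afterwards on a dataset dict whose values are lists with one entry per subject (the output filenames returned by the subject-level runs), and skip_save controls which dataset keys write_dataset writes to disk. A minimal usage sketch follows; the input paths, the bad_segments options and the count_subjects helper are illustrative assumptions, not part of the patches.

# Sketch only: assumes the osl branch carrying these patches is installed.
from osl.preprocessing.batch import run_proc_batch

config = """
    preproc:
    - bad_segments: {segment_len: 800, picks: mag}
    group:
    - count_subjects: {}
"""

def count_subjects(dataset, userargs):
    # Hypothetical user-defined group stage, resolved by find_func via
    # extra_funcs. At the group level dataset['raw'] holds the list of
    # preprocessed fif filenames produced by the subject-level runs.
    dataset["n_subjects"] = len(dataset["raw"])
    return dataset

if __name__ == "__main__":
    inputs = ["sub-01_raw.fif", "sub-02_raw.fif"]  # hypothetical raw files
    run_proc_batch(
        config,
        inputs,
        outdir="./processed",
        extra_funcs=[count_subjects],
        skip_save=["events"],  # keys listed here are not written to disk
        overwrite=True,
    )

Group-level outputs are written once via write_dataset with outbase set to outdir: keys without a dedicated writer are pickled, so anything a group stage adds to the dataset dict (n_subjects above would land in n_subjects.pkl) ends up alongside the per-subject outputs.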