From ed7141384ddf544fc83746e81a4a579da76356bd Mon Sep 17 00:00:00 2001 From: matsvanes Date: Mon, 9 Sep 2024 14:55:57 +0100 Subject: [PATCH] add read_dataset func --- osl/preprocessing/osl_wrappers.py | 79 +++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/osl/preprocessing/osl_wrappers.py b/osl/preprocessing/osl_wrappers.py index b31a909f..8710fe34 100644 --- a/osl/preprocessing/osl_wrappers.py +++ b/osl/preprocessing/osl_wrappers.py @@ -10,8 +10,10 @@ import mne import numpy as np import sails +import yaml from os.path import exists from scipy import stats +from pathlib import Path logger = logging.getLogger(__name__) @@ -620,7 +622,84 @@ def drop_bad_epochs( # Wrapper functions +def run_osl_read_dataset(dataset, userargs): + """Reads ``fif``/``npy``/``yml`` files associated with a dataset. + Parameters + ---------- + fif : str + Path to raw fif file (can be preprocessed). + preload : bool + Should we load the raw fif data? + ftype : str + Extension for the fif file (will be replaced for e.g. ``'_events.npy'`` or + ``'_ica.fif'``). If ``None``, we assume the fif file is preprocessed with + OSL and has the extension ``'_preproc-raw'``. If this fails, we guess + the extension as whatever comes after the last ``'_'``. + + Returns + ------- + dataset : dict + Contains keys: ``'raw'``, ``'events'``, ``'event_id'``, ``'epochs'``, ``'ica'``. + """ + + logger.info("OSL Stage - {0}".format( "read_dataset")) + logger.info("userargs: {0}".format(str(userargs))) + ftype = userargs.pop("ftype", None) + + fif = dataset['raw'].filenames[0] + + # Guess extension + if ftype is None: + logger.info("Guessing the preproc extension") + if "preproc-raw" in fif: + logger.info('Assuming fif file type is "preproc-raw"') + ftype = "preproc-raw" + else: + if len(fif.split("_"))<2: + logger.error("Unable to guess the fif file extension") + else: + logger.info('Assuming fif file type is the last "_" separated string') + ftype = fif.split("_")[-1].split('.')[-2] + + # add extension to fif file name + ftype = ftype + ".fif" + + events = Path(fif.replace(ftype, "events.npy")) + if events.exists(): + print("Reading", events) + events = np.load(events) + else: + events = None + + event_id = Path(fif.replace(ftype, "event-id.yml")) + if event_id.exists(): + print("Reading", event_id) + with open(event_id, "r") as file: + event_id = yaml.load(file, Loader=yaml.Loader) + else: + event_id = None + + epochs = Path(fif.replace(ftype, "epo.fif")) + if epochs.exists(): + print("Reading", epochs) + epochs = mne.read_epochs(epochs) + else: + epochs = None + + ica = Path(fif.replace(ftype, "ica.fif")) + if ica.exists(): + print("Reading", ica) + ica = mne.preprocessing.read_ica(ica) + else: + ica = None + + dataset['event_id'] = event_id + dataset['events'] = events + dataset['ica'] = ica + dataset['epochs'] = epochs + + return dataset def run_osl_bad_segments(dataset, userargs): """OSL-Batch wrapper for :py:meth:`detect_badsegments `.