diff --git a/osl/preprocessing/batch.py b/osl/preprocessing/batch.py
index 48ebeb6..5f2ad37 100644
--- a/osl/preprocessing/batch.py
+++ b/osl/preprocessing/batch.py
@@ -17,6 +17,7 @@
 import traceback
 import re
 import logging
+import pickle
 from pathlib import Path
 from copy import deepcopy
 from functools import partial, wraps
@@ -254,33 +255,58 @@ def load_config(config):
     elif "versions" not in config['meta']:
         config["meta"]["versions"] = None
 
-    if "preproc" not in config:
-        raise KeyError("Please specify preprocessing steps in config.")
+    if "preproc" not in config and "group" not in config:
+        raise KeyError("Please specify preprocessing and/or group processing steps in config.")
+
+    if "preproc" in config:
+        for stage in config["preproc"]:
+            # Check each stage is a dictionary with a single key
+            if not isinstance(stage, dict):
+                raise ValueError(
+                    "Preprocessing stage '{0}' is a {1} not a dict".format(
+                        stage, type(stage)
+                    )
+                )
 
-    for stage in config["preproc"]:
-        # Check each stage is a dictionary with a single key
-        if not isinstance(stage, dict):
-            raise ValueError(
-                "Preprocessing stage '{0}' is a {1} not a dict".format(
-                    stage, type(stage)
+            if len(stage) != 1:
+                raise ValueError(
+                    "Preprocessing stage '{0}' should only have a single key".format(stage)
                 )
-            )
 
-        if len(stage) != 1:
-            raise ValueError(
-                "Preprocessing stage '{0}' should only have a single key".format(stage)
-            )
+            for key, val in stage.items():
+                # internally we want options to be an empty dict (for now at least)
+                if val in ["null", "None", None]:
+                    stage[key] = {}
+
+        for step in config["preproc"]:
+            if config["meta"]["event_codes"] is None and "find_events" in step.values():
+                raise KeyError(
+                    "event_codes must be passed in config if we are finding events."
+                )
+    else:
+        config['preproc'] = None
+
+    if "group" in config:
+        for stage in config["group"]:
+            # Check each stage is a dictionary with a single key
+            if not isinstance(stage, dict):
+                raise ValueError(
+                    "Group processing stage '{0}' is a {1} not a dict".format(
+                        stage, type(stage)
+                    )
+                )
 
-        for key, val in stage.items():
-            # internally we want options to be an empty dict (for now at least)
-            if val in ["null", "None", None]:
-                stage[key] = {}
+            if len(stage) != 1:
+                raise ValueError(
+                    "Group processing stage '{0}' should only have a single key".format(stage)
+                )
 
-    for step in config["preproc"]:
-        if config["meta"]["event_codes"] is None and "find_events" in step.values():
-            raise KeyError(
-                "event_codes must be passed in config if we are finding events."
-            )
+            for key, val in stage.items():
+                # internally we want options to be an empty dict (for now at least)
+                if val in ["null", "None", None]:
+                    stage[key] = {}
+    else:
+        config['group'] = None
 
     return config
 
@@ -386,7 +412,7 @@ def append_preproc_info(dataset, config, extra_funcs=None):
     return dataset
 
 
-def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False):
+def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False, skip=None):
     """Write preprocessed data to a file.
 
     Will write all keys in the dataset dict to disk with corresponding extensions.
@@ -403,49 +429,76 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False
         Extension for the fif file (default ``preproc-raw``)
     overwrite : bool
         Should we overwrite if the file already exists?
+    skip : list or None
+        List of keys to skip writing to disk. If None, we don't skip any keys.
 
     Output
     ------
-    fif_outname : str
-        The saved fif file name
+    outnames : dict
+        Dictionary of saved output file names, keyed by dataset key.
     """
+
+    if skip is None:
+        skip = []
+    else:
+        [logger.info("Skip saving of dataset['{}']".format(key)) for key in skip]
 
     # Strip "_preproc-raw" or "_raw" from the run id
     for string in ["_preproc-raw", "_raw"]:
         if string in run_id:
             run_id = run_id.replace(string, "")
+
+    if "raw" in skip:
+        outnames = {}
+    else:
+        outnames = {"raw": outbase.format(run_id=run_id, ftype=ftype, fext="fif")}
+        if Path(outnames["raw"]).exists() and not overwrite:
+            raise ValueError(
+                "{} already exists. Please delete or use overwrite=True.".format(outnames["raw"])
+            )
+        logger.info(f"Saving dataset['raw'] as {outnames['raw']}")
+        dataset["raw"].save(outnames['raw'], overwrite=overwrite)
+
+    if "events" in dataset and "events" not in skip and dataset['events'] is not None:
+        outnames['events'] = outbase.format(run_id=run_id, ftype="events", fext="npy")
+        logger.info(f"Saving dataset['events'] as {outnames['events']}")
+        np.save(outnames['events'], dataset["events"])
+
+    if "event_id" in dataset and "event_id" not in skip and dataset['event_id'] is not None:
+        outnames['event_id'] = outbase.format(run_id=run_id, ftype="event-id", fext="yml")
+        logger.info(f"Saving dataset['event_id'] as {outnames['event_id']}")
+        yaml.dump(dataset["event_id"], open(outnames['event_id'], "w"))
+
+    if "epochs" in dataset and "epochs" not in skip and dataset['epochs'] is not None:
+        outnames['epochs'] = outbase.format(run_id=run_id, ftype="epo", fext="fif")
+        logger.info(f"Saving dataset['epochs'] as {outnames['epochs']}")
+        dataset["epochs"].save(outnames['epochs'], overwrite=overwrite)
+
+    if "ica" in dataset and "ica" not in skip and dataset['ica'] is not None:
+        outnames['ica'] = outbase.format(run_id=run_id, ftype="ica", fext="fif")
+        logger.info(f"Saving dataset['ica'] as {outnames['ica']}")
+        dataset["ica"].save(outnames['ica'], overwrite=overwrite)
+
+    if "tfr" in dataset and "tfr" not in skip and dataset['tfr'] is not None:
+        outnames['tfr'] = outbase.format(run_id=run_id, ftype="tfr", fext="fif")
+        logger.info(f"Saving dataset['tfr'] as {outnames['tfr']}")
+        dataset["tfr"].save(outnames['tfr'], overwrite=overwrite)
+
+    if "glm" in dataset and "glm" not in skip and dataset['glm'] is not None:
+        outnames['glm'] = outbase.format(run_id=run_id, ftype="glm", fext="pkl")
+        logger.info(f"Saving dataset['glm'] as {outnames['glm']}")
+        dataset["glm"].save_pkl(outnames['glm'], overwrite=overwrite)
+
+    # save remaining keys as pickle files
+    for key in dataset:
+        if key not in outnames and key not in skip:
+            outnames[key] = outbase.format(run_id=run_id, ftype=key, fext="pkl")
+            logger.info(f"Saving dataset['{key}'] as {outnames[key]}")
+            if (not os.path.exists(outnames[key]) or overwrite) and key not in skip and dataset[key] is not None:
+                with open(outnames[key], "wb") as f:
+                    pickle.dump(dataset[key], f)
+
+    return outnames
 
-    fif_outname = outbase.format(run_id=run_id, ftype=ftype, fext="fif")
-    if Path(fif_outname).exists() and not overwrite:
-        raise ValueError(
-            "{} already exists. Please delete or do use overwrite=True.".format(fif_outname)
-        )
-    dataset["raw"].save(fif_outname, overwrite=overwrite)
-
-    if "events" in dataset and dataset['events'] is not None:
-        outname = outbase.format(run_id=run_id, ftype="events", fext="npy")
-        np.save(outname, dataset["events"])
-
-    if "event_id" in dataset and dataset['event_id'] is not None:
-        outname = outbase.format(run_id=run_id, ftype="event-id", fext="yml")
-        yaml.dump(dataset["event_id"], open(outname, "w"))
-
-    if "epochs" in dataset and dataset['epochs'] is not None:
-        outname = outbase.format(run_id=run_id, ftype="epo", fext="fif")
-        dataset["epochs"].save(outname, overwrite=overwrite)
-
-    if "ica" in dataset and dataset['ica'] is not None:
-        outname = outbase.format(run_id=run_id, ftype="ica", fext="fif")
-        dataset["ica"].save(outname, overwrite=overwrite)
-
-    if "tfr" in dataset and dataset['tfr'] is not None:
-        outname = outbase.format(run_id=run_id, ftype="tfr", fext="fif")
-        dataset["tfr"].save(outname, overwrite=overwrite)
-
-    if "glm" in dataset and dataset['glm'] is not None:
-        outname = outbase.format(run_id=run_id, ftype="glm", fext="pkl")
-        dataset["glm"].save_pkl(outname, overwrite=overwrite)
-
-    return fif_outname
 
 
 def read_dataset(fif, preload=False, ftype=None):
     """Reads ``fif``/``npy``/``yml`` files associated with a dataset.
@@ -573,11 +626,16 @@ def plot_preproc_flowchart(
     ax.set_xticks([])
     ax.set_yticks([])
     if title == None:
-        ax.set_title("OSL Preprocessing Recipe", fontsize=24)
+        ax.set_title("OSL Processing Recipe", fontsize=24)
     else:
         ax.set_title(title, fontsize=24)
-
-    stage_height = 1 / (1 + len(config["preproc"]))
+
+    tmp_h = 1
+    if config["preproc"] is not None:
+        tmp_h += 1 + len(config["preproc"])
+    if config["group"] is not None:
+        tmp_h += 1 + len(config["group"])
+    stage_height = 1 / tmp_h
 
     box = dict(boxstyle="round", facecolor=stagecol, alpha=1, pad=0.3)
     startbox = dict(boxstyle="round", facecolor=startcol, alpha=1)
@@ -587,12 +645,16 @@ def plot_preproc_flowchart(
         "weight": "normal",
         "size": 16,
     }
-
-    stages = [{"input": ""}, *config["preproc"], {"output": ""}]
+    stages = [{"input": ""}]
+    if config['preproc'] is not None:
+        stages += [{"preproc": ""}, *config["preproc"]]
+    if config['group'] is not None:
+        stages += [{"group": ""}, *config["group"]]
+    stages.append({"output": ""})
     stage_str = "$\\bf{{{0}}}$ {1}"
 
     ax.arrow(
-        0.5, 1, 0.0, -1, fc="k", ec="k", head_width=0.045,
+        0.5, 1, 0.0, -1+0.02, fc="k", ec="k", head_width=0.045,
         head_length=0.035, length_includes_head=True,
     )
@@ -600,7 +662,7 @@ def plot_preproc_flowchart(
         method, userargs = next(iter(stage.items()))
         method = method.replace("_", "\_")
 
-        if method in ["input", "output"]:
+        if method in ["input", "preproc", "group", "output"]:
             b = startbox
         else:
             b = box
@@ -644,6 +706,7 @@ def run_proc_chain(
     ret_dataset=True,
     gen_report=None,
     overwrite=False,
+    skip_save=None,
     extra_funcs=None,
     random_seed='auto',
     verbose="INFO",
@@ -673,6 +736,8 @@ def run_proc_chain(
         Should we generate a report?
     overwrite : bool
         Should we overwrite the output file if it already exists?
+    skip_save : list or None (default)
+        List of keys to skip writing to disk. If None, we don't skip any keys.
     extra_funcs : list
         User-defined functions.
     random_seed : 'auto' (default), int or None
@@ -799,9 +864,9 @@ def run_proc_chain(
     # Add preprocessing info to dataset dict
     dataset = append_preproc_info(dataset, config, extra_funcs)
 
-    fif_outname = None
+    outnames = {"raw": None}
     if outdir is not None:
-        fif_outname = write_dataset(dataset, outbase, run_id, overwrite=overwrite)
+        outnames = write_dataset(dataset, outbase, run_id, overwrite=overwrite, skip=skip_save)
 
     # Generate report data
     if gen_report:
@@ -811,12 +876,12 @@ def run_proc_chain(
         from ..report import gen_html_data, gen_html_page  # avoids circular import
 
         logger.info("{0} : Generating Report".format(now))
-        report_data_dir = validate_outdir(reportdir / Path(fif_outname).stem)
+        report_data_dir = validate_outdir(reportdir / Path(outnames["raw"]).stem)
         gen_html_data(
             dataset["raw"],
             report_data_dir,
             ica=dataset["ica"],
-            preproc_fif_filename=fif_outname,
+            preproc_fif_filename=outnames["raw"],
             logsdir=logsdir,
             run_id=run_id,
         )
@@ -843,9 +908,9 @@ def run_proc_chain(
         logger.error(traceback.print_tb(ex_traceback))
 
         with open(logfile.replace(".log", ".error.log"), "w") as f:
-            f.write("OSL PREPROCESSING CHAIN failed at: {0}".format(now))
+            f.write("OSL PREPROCESSING CHAIN FAILED AT: {0}".format(now))
             f.write("\n")
-            f.write('Processing filed during stage : "{0}"'.format(method))
+            f.write('Processing failed during stage : "{0}"'.format(method))
             f.write(str(ex_type))
             f.write("\n")
             f.write(str(ex_value))
@@ -858,17 +923,21 @@ def run_proc_chain(
             # variable type
             return {}
         else:
+            if 'group' in config:
+                return False, None
             return False
 
     now = strftime("%Y-%m-%d %H:%M:%S", localtime())
     logger.info("{0} : Processing Complete".format(now))
 
-    if fif_outname is not None:
-        logger.info("Output file is {}".format(fif_outname))
+    if outnames["raw"] is not None:
+        logger.info("Output file is {}".format(outnames["raw"]))
 
     if ret_dataset:
         return dataset
     else:
+        if 'group' in config:
+            return True, outnames
         return True
 
 
@@ -882,6 +951,7 @@ def run_proc_batch(
     reportdir=None,
     gen_report=True,
     overwrite=False,
+    skip_save=None,
    extra_funcs=None,
     random_seed='auto',
     verbose="INFO",
@@ -915,6 +985,8 @@ def run_proc_batch(
         Should we generate a report?
     overwrite : bool
         Should we overwrite the output file if it exists?
+    skip_save : list or None (default)
+        List of keys to skip writing to disk. If None, we don't skip any keys.
     extra_funcs : list
         User-defined functions.
     random_seed : 'auto' (default), int or None
@@ -1011,42 +1083,83 @@ def run_proc_batch(
     if strictrun:
         logger.info('User confirms input config')
 
-    # Create partial function with fixed options
-    pool_func = partial(
-        run_proc_chain,
-        outdir=outdir,
-        ftype=ftype,
-        logsdir=logsdir,
-        reportdir=reportdir,
-        ret_dataset=False,
-        gen_report=gen_report,
-        overwrite=overwrite,
-        extra_funcs=extra_funcs,
-        random_seed=random_seed,
-    )
+    if config['preproc'] is not None:
+        # Create partial function with fixed options
+        pool_func = partial(
+            run_proc_chain,
+            outdir=outdir,
+            ftype=ftype,
+            logsdir=logsdir,
+            reportdir=reportdir,
+            ret_dataset=False,
+            gen_report=gen_report,
+            overwrite=overwrite,
+            skip_save=skip_save,
+            extra_funcs=extra_funcs,
+            random_seed=random_seed,
+        )
 
-    # Loop through input files to generate arguments for run_proc_chain
-    args = []
-    for infile, subject in zip(infiles, subjects):
-        args.append((config, infile, subject))
+        # Loop through input files to generate arguments for run_proc_chain
+        args = []
+        for infile, subject in zip(infiles, subjects):
+            args.append((config, infile, subject))
 
-    # Actually run the processes
-    if dask_client:
-        proc_flags = dask_parallel_bag(pool_func, args)
+        # Actually run the processes
+        if dask_client:
+            proc_flags = dask_parallel_bag(pool_func, args)
+        else:
+            proc_flags = [pool_func(*aa) for aa in args]
+
+        if isinstance(proc_flags[0], tuple):
+            group_inputs = [flag[1] for flag in proc_flags]
+            proc_flags = [flag[0] for flag in proc_flags]
+
+        osl_logger.set_up(log_file=logfile, level=verbose, startup=False)
+        logger.info(
+            "Processed {0}/{1} files successfully".format(
+                np.sum(proc_flags), len(proc_flags)
+            )
+        )
+
+        # Generate a report
+        if gen_report and len(infiles) > 0:
+            from ..report import preproc_report  # avoids circular import
+            preproc_report.gen_html_page(reportdir)
+
     else:
-        proc_flags = [pool_func(*aa) for aa in args]
+        group_inputs = [{"raw": infile} for infile in infiles]
+        proc_flags = [None for sub in infiles]
+
+        osl_logger.set_up(log_file=logfile, level=verbose, startup=False)
+        logger.info("No preprocessing steps specified. Skipping preprocessing.")
Skipping preprocessing.") - osl_logger.set_up(log_file=logfile, level=verbose, startup=False) - logger.info( - "Processed {0}/{1} files successfully".format( - np.sum(proc_flags), len(proc_flags) + + # start group processing + if config['group'] is not None: + logger.info("Starting Group Processing") + logger.info( + "Valid input files {0}/{1}".format( + np.sum(proc_flags), len(proc_flags) + ) ) - ) + dataset = {} + skip_save=[] + for key in group_inputs[0]: + dataset[key] = [group_inputs[i][key] for i in range(len(group_inputs))] + skip_save.append(key) + for stage in deepcopy(config["group"]): + method, userargs = next(iter(stage.items())) + target = userargs.get("target", "raw") # Raw is default + # skip.append(stage if userargs.get("skip_save") is True else None) # skip saving this stage to disk + func = find_func(method, target=target, extra_funcs=extra_funcs) + # Actual function call + dataset = func(dataset, userargs) + outbase = os.path.join(outdir, "{ftype}.{fext}") + outnames = write_dataset(dataset, outbase, '', ftype='', overwrite=overwrite, skip=skip_save) - # Generate a report - if gen_report and len(infiles) > 0: + # rerun the summary report + if gen_report: from ..report import preproc_report # avoids circular import - preproc_report.gen_html_page(reportdir) if preproc_report.gen_html_summary(reportdir, logsdir): logger.info("******************************" + "*" * len(str(reportdir))) logger.info(f"* REMEMBER TO CHECK REPORT: {reportdir} *")