Skip to content

Commit

Permalink
set random seed by default in batch processing (preproc + source_recon)
Browse files Browse the repository at this point in the history
  • Loading branch information
matsvanes committed Sep 5, 2024
1 parent 6446af3 commit 03d3059
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 1 deletion.
28 changes: 27 additions & 1 deletion osl/preprocessing/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..utils import logger as osl_logger
from ..utils.parallel import dask_parallel_bag
from ..utils.version_utils import check_version
from ..utils.misc import set_random_seed

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -641,6 +642,7 @@ def run_proc_chain(
gen_report=None,
overwrite=False,
extra_funcs=None,
random_seed='auto',
verbose="INFO",
mneverbose="WARNING",
):
Expand Down Expand Up @@ -670,6 +672,9 @@ def run_proc_chain(
Should we overwrite the output file if it already exists?
extra_funcs : list
User-defined functions.
random_seed : 'auto' (default), int or None
Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy.
If None, no random seed is set.
verbose : str
Level of info to print.
Can be: ``'CRITICAL'``, ``'ERROR'``, ``'WARNING'``, ``'INFO'``, ``'DEBUG'`` or ``'NOTSET'``.
Expand Down Expand Up @@ -740,6 +745,14 @@ def run_proc_chain(
logger.info("{0} : Starting OSL Processing".format(now))
logger.info("input : {0}".format(infile))

# Set random seed
if random_seed == 'auto':
set_random_seed()
elif random_seed is None:
pass
else:
set_random_seed(random_seed)

# Write preprocessed data to output directory
if outdir is not None:
# Check for existing outputs - should be a .fif at least
Expand Down Expand Up @@ -867,6 +880,7 @@ def run_proc_batch(
gen_report=True,
overwrite=False,
extra_funcs=None,
random_seed='auto',
verbose="INFO",
mneverbose="WARNING",
strictrun=False,
Expand Down Expand Up @@ -898,6 +912,9 @@ def run_proc_batch(
Should we generate a report?
overwrite : bool
Should we overwrite the output file if it exists?
random_seed : 'auto' (default), int or None
Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy.
If None, no random seed is set.
extra_funcs : list
User-defined functions.
verbose : str
Expand All @@ -924,7 +941,7 @@ def run_proc_batch(
>>> from dask.distributed import Client
>>> client = Client(threads_per_worker=1, n_workers=4)
"""

if outdir is None:
# Use the current working directory
outdir = os.getcwd()
Expand All @@ -944,6 +961,14 @@ def run_proc_batch(

logger.info('Starting OSL Batch Processing')

# Set random seed
if random_seed == 'auto':
random_seed = set_random_seed()
elif random_seed is None:
pass
else:
set_random_seed(random_seed)

# Check through inputs and parameters
infiles, good_files_outnames, good_files = process_file_inputs(files)

Expand Down Expand Up @@ -994,6 +1019,7 @@ def run_proc_batch(
gen_report=gen_report,
overwrite=overwrite,
extra_funcs=extra_funcs,
random_seed=random_seed,
)

# Loop through input files to generate arguments for run_proc_chain
Expand Down
26 changes: 26 additions & 0 deletions osl/source_recon/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..report import src_report
from ..utils import logger as osl_logger
from ..utils import validate_outdir, find_run_id, parallel
from ..utils.misc import set_random_seed

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -101,6 +102,7 @@ def run_src_chain(
verbose="INFO",
mneverbose="WARNING",
extra_funcs=None,
random_seed='auto',
):
"""Source reconstruction.
Expand Down Expand Up @@ -130,6 +132,9 @@ def run_src_chain(
Level of MNE verbose.
extra_funcs : list of functions
Custom functions.
random_seed : 'auto' (default), int or None
Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy.
If None, no random seed is set.
Returns
-------
Expand Down Expand Up @@ -160,6 +165,14 @@ def run_src_chain(
logger.info("{0} : Starting OSL Processing".format(now))
logger.info("input : {0}".format(outdir / subject))

# Set random seed
if random_seed == 'auto':
set_random_seed()
elif random_seed is None:
pass
else:
set_random_seed(random_seed)

# Load config
if not isinstance(config, dict):
config = load_config(config)
Expand Down Expand Up @@ -251,6 +264,7 @@ def run_src_batch(
mneverbose="WARNING",
extra_funcs=None,
dask_client=False,
random_seed='auto',
):
"""Batch source reconstruction.
Expand Down Expand Up @@ -283,6 +297,9 @@ def run_src_batch(
Custom functions.
dask_client : bool
Are we using a dask client?
random_seed : 'auto' (default), int or None
Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy.
If None, no random seed is set.
Returns
-------
Expand All @@ -302,6 +319,14 @@ def run_src_batch(
osl_logger.set_up(log_file=logfile, level=verbose, startup=False)
logger.info('Starting OSL Batch Source Reconstruction')

# Set random seed
if random_seed == 'auto':
random_seed = set_random_seed()
elif random_seed is None:
pass
else:
set_random_seed(random_seed)

# Load config
config = load_config(config)
config_str = pprint.PrettyPrinter().pformat(config)
Expand Down Expand Up @@ -377,6 +402,7 @@ def run_src_batch(
verbose=verbose,
mneverbose=mneverbose,
extra_funcs=extra_funcs,
random_seed=random_seed,
)

# Loop through input files to generate arguments for run_coreg_chain
Expand Down
30 changes: 30 additions & 0 deletions osl/utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Miscellaneous utility classes and functions.
"""

import logging
import random
import numpy as np


logger = logging.getLogger(__name__)


def set_random_seed(seed=None):
"""Set all random seeds.
This includes Python's random module and NumPy.
Parameters
----------
seed : int
Random seed.
"""
if seed is None:
seed = random.randint(0, 2**32 - 1)

logger.info(f"Setting random seed to {seed}")

random.seed(seed)
np.random.seed(seed)
return seed

0 comments on commit 03d3059

Please sign in to comment.