Skip to content

Commit

Permalink
Merge branch 'main' into plot_source
Browse files Browse the repository at this point in the history
  • Loading branch information
cgohil committed Mar 5, 2024
2 parents 8902efe + a6eb595 commit 149f455
Show file tree
Hide file tree
Showing 16 changed files with 483 additions and 73 deletions.
59 changes: 59 additions & 0 deletions examples/spectrum_analysis_walkthrough.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import osl
from scipy import signal
import matplotlib.pyplot as plt

raw = osl.utils.simulate_raw_from_template(10000, noise=1/3)
raw.pick(picks='mag')


#%%
spec = osl.glm.glm_spectrum(raw)
spec.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5, title='testing123')

#%%
aper, osc = osl.glm.glm_irasa(raw, mode='magnitude')
plt.figure()
ax = plt.subplot(121)
aper.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(122)
osc.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5,ax=ax)


#%%
alpha = raw.copy().filter(l_freq=7, h_freq=13)
covs = {'alpha': np.abs(signal.hilbert(alpha.get_data()[raw.ch_names.index('MEG1711'), :]))}

spec = osl.glm.glm_spectrum(raw, reg_ztrans=covs)

plt.figure()
ax = plt.subplot(121)
spec.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(122)
spec.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax)




aper, osc = osl.glm.glm_irasa(raw, reg_ztrans=covs)

plt.figure()
ax = plt.subplot(221)
aper.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(222)
aper.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(223)
osc.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(224)
osc.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax)




gglmsp = osl.glm.read_glm_spectrum('/Users/andrew/Downloads/bigmeg-camcan-movecomptrans_glm-spectrum_grad-noztrans_group-level.pkl')
spec = osl.glm.GroupSensorGLMSpectrum(gglmsp.model,
gglmsp.design,
gglmsp.config,
gglmsp.info,
fl_contrast_names=None,
data=gglmsp.data)
P = osl.glm.MaxStatPermuteGLMSpectrum(spec, 1, nperms=25)
260 changes: 207 additions & 53 deletions osl/glm/glm_spectrum.py

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions osl/preprocessing/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..utils import find_run_id, validate_outdir, process_file_inputs, add_subdir
from ..utils import logger as osl_logger
from ..utils.parallel import dask_parallel_bag
from ..utils.version_utils import check_version

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -240,6 +241,8 @@ def load_config(config):
config["meta"] = {"event_codes": None}
elif "event_codes" not in config["meta"]:
config["meta"]["event_codes"] = None
elif "versions" not in config['meta']:
config["meta"]["versions"] = None

if "preproc" not in config:
raise KeyError("Please specify preprocessing steps in config.")
Expand Down Expand Up @@ -272,6 +275,36 @@ def load_config(config):
return config


def check_config_versions(config):
"""Get config from a preprocessed fif file.
Parameters
----------
config : dictionary or yaml string
Preprocessing configuration to check.
Raises
------
AssertionError
Raised if package version mismatch found in 'version_assert'
WARNING
Raised if package version mismatch found in 'version_warn'
"""
config = load_config(config)

# Check for version and raise an error if mismatch found
if 'version_assert' in config['meta']:
for vers in config['meta']['version_assert']:
check_version(vers, mode='assert')

# Check for version and raise a warning if mismatch found
if 'version_warn' in config['meta']:
for vers in config['meta']['version_warn']:
check_version(vers, mode='warn')


def get_config_from_fif(inst):
"""Get config from a preprocessed fif file.
Expand Down
6 changes: 5 additions & 1 deletion osl/report/preproc_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
from jinja2 import Template
from tabulate import tabulate
from mne.channels.channels import channel_type
from scipy.ndimage.filters import uniform_filter1d
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from pathlib import Path

try:
from scipy.ndimage import uniform_filter1d
except ImportError:
from scipy.ndimage.filters import uniform_filter1d

from ..utils import process_file_inputs, validate_outdir
from ..utils.logger import log_or_print
from ..preprocessing import (
Expand Down
2 changes: 1 addition & 1 deletion osl/report/src_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def save_extra_funcs(extra_funcs, reportdir):
Path to saved text file.
"""

if reportdir is not None:
if reportdir is not None and extra_funcs is not None:
fpath = reportdir / 'extra_funcs.txt'
with(open(fpath, 'w')) as file:
[print(f"{inspect.getsource(func)}\n\n", file=file) for func in extra_funcs]
Expand Down
13 changes: 10 additions & 3 deletions osl/source_recon/beamforming.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
import matplotlib.pyplot as plt
import mne
from mne import read_forward_solution, Covariance, compute_covariance, compute_raw_covariance
from mne.io.meas_info import _simplify_info
from mne.io.pick import pick_channels_cov, pick_info
from mne.io.proj import make_projector
from mne.rank import compute_rank
from mne.minimum_norm.inverse import _check_depth, _prepare_forward, _get_vertno
from mne.source_estimate import _get_src_type
Expand All @@ -39,6 +36,16 @@
)
from mne.utils import logger as mne_logger

try:
from mne._fiff.meas_info import _simplify_info
from mne._fiff.pick import pick_channels_cov, pick_info
from mne._fiff.proj import make_projector
except ImportError:
# Depreciated in mne 1.6
from mne.io.meas_info import _simplify_info
from mne.io.pick import pick_channels_cov, pick_info
from mne.io.proj import make_projector

from osl.source_recon import rhino
from osl.source_recon.rhino import utils as rhino_utils
from osl.utils.logger import log_or_print
Expand Down
17 changes: 14 additions & 3 deletions osl/source_recon/rhino/coreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,19 @@
from mne.viz.backends.renderer import _get_renderer
from mne.transforms import write_trans, read_trans, apply_trans, _get_trans, combine_transforms, Transform, rotation, invert_transform
from mne.forward import _create_meg_coils
from mne.io import _loc_to_coil_trans, read_info, read_raw, RawArray
from mne.io.pick import pick_types
from mne.io import read_info, read_raw, RawArray

try:
from mne import pick_types
except ImportError:
# Depreciated in mne 1.6
from mne.io.pick import pick_types

try:
from mne._fiff.tag import _loc_to_coil_trans
except ImportError:
# Depreciated in mne 1.6
from mne.io import _loc_to_coil_trans

from fsl import wrappers as fsl_wrappers

Expand Down Expand Up @@ -604,7 +615,7 @@ def coreg_display(

meg_picks = pick_types(info, meg=True, ref_meg=False, exclude=())

coil_transs = [_loc_to_coil_trans(info["chs"][pick]["loc"]) for pick in meg_picks ]
coil_transs = [_loc_to_coil_trans(info["chs"][pick]["loc"]) for pick in meg_picks]
coils = _create_meg_coils([info["chs"][pick] for pick in meg_picks], acc="normal")

meg_rrs, meg_tris, meg_sensor_locs, meg_sensor_oris = (list(), list(), list(), list())
Expand Down
7 changes: 6 additions & 1 deletion osl/source_recon/rhino/forward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from mne.io import read_info
from mne.io.constants import FIFF
from mne.surface import read_surface, write_surface
from mne.source_space import _make_volume_source_space, _complete_vol_src

try:
from mne.source_space import _make_volume_source_space, _complete_vol_src
except ImportError:
# Depreciated in mne 1.6
from mne.source_space._source_space import _make_volume_source_space, _complete_vol_src

import osl.source_recon.rhino.utils as rhino_utils
from osl.source_recon.rhino import get_coreg_filenames
Expand Down
23 changes: 23 additions & 0 deletions osl/tests/test_batch_preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ def test_simple_chain(self):
assert(isinstance(dataset["raw"], mne.io.fiff.raw.Raw))


class TestVersions(unittest.TestCase):
def test_simple_chain(self):
from ..preprocessing import load_config, check_config_versions

cfg = """
meta:
event_codes:
version_assert:
version_warn:
preproc:
- filter: {l_freq: 1, h_freq: 30}
- notch_filter: {freqs: 50}
- bad_channels: {picks: 'grad'}
- bad_segments: {segment_len: 800, picks: 'grad'}
"""
config = load_config(cfg)

config['meta']['version_assert'] = ['numpy>1.0', 'scipy>1.0']
config['meta']['version_warn'] = ['mne>1.0']

check_config_versions(config)


class TestPreprocessingBatch(unittest.TestCase):

@classmethod
Expand Down
41 changes: 41 additions & 0 deletions osl/tests/test_glm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Tests for glm_spectrum and glm_epochs"""

import unittest
import tempfile
import os

import mne
import numpy as np


class TestGLMSpectrum(unittest.TestCase):

@classmethod
def setUpClass(cls):
from ..utils import simulate_raw_from_template

cls.flat_channels = None
cls.bad_channels = None
cls.bad_segments = None

cls.raw = simulate_raw_from_template(500,
flat_channels=cls.flat_channels,
bad_channels=cls.bad_channels,
bad_segments=cls.bad_segments)

cls.fpath = tempfile.NamedTemporaryFile().name + 'raw.fif'
cls.raw.save(cls.fpath)

@classmethod
def tearDownClass(cls):
os.remove(cls.fpath)

def test_glm_spectrum(self):
from ..glm import glm_spectrum

spec = glm_spectrum(self.raw)

def test_glm_irasa(self):
from ..glm import glm_irasa

aper, osc = glm_irasa(self.raw)
3 changes: 2 additions & 1 deletion osl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from .simulate import * # noqa: F401, F403
from .opm import * # noqa: F401, F403
from .package import soft_import, run_package_tests # noqa: F401, F403
from .version_utils import check_version # noqa: F401, F403
from . import run_func # noqa: F401, F403

with open(os.path.join(os.path.dirname(__file__), "README.md"), 'r') as f:
__doc__ = f.read()
__doc__ = f.read()
4 changes: 2 additions & 2 deletions osl/utils/file_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def process_file_inputs(inputs):
for row in inputs:
infiles.append(sanitise_filepath(row[0]))
outnames.append(row[1])
elif isinstance(inputs[0], mne.io.fiff.raw.Raw):
elif isinstance(inputs[0], mne.io.Raw):
# We have a list of MNE objects
infiles = infiles
check_paths = False
Expand Down Expand Up @@ -125,7 +125,7 @@ def find_run_id(infile, preload=True):
# the fif option for everything except BTI scans? They're basically the
# same now.

if isinstance(infile, mne.io.fiff.raw.Raw):
if isinstance(infile, mne.io.Raw):
infile = infile.filenames[0]

if os.path.split(infile)[1] == 'c,rfDC':
Expand Down
10 changes: 7 additions & 3 deletions osl/utils/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
import numpy as np

import mne
from mne.io import _coil_trans_to_loc
from mne.io.constants import FIFF
from mne.transforms import Transform, apply_trans

try:
from mne._fiff.tag import _coil_trans_to_loc
except ImportError:
# Depreciated in mne 1.6
from mne.io import _coil_trans_to_loc

import pandas as pd
import scipy

Expand Down Expand Up @@ -234,10 +239,9 @@ def correct_mri(smri_file, smri_fixed_file):

return sform_std

#########################################################################

# -------------------------------------------------------------
# %% Debug and plotting code for checking sensor locs and oris

if False:

from mne.io.pick import pick_types
Expand Down
12 changes: 8 additions & 4 deletions osl/utils/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np


def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True):
def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True, noise=None):
"""Simulate data from a linear model.
Parameters
Expand All @@ -31,7 +31,6 @@ def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True):
"""


num_sources = model.nsignals

# Preallocate output
Expand All @@ -50,10 +49,15 @@ def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True):
for t in range(model.order, num_samples):
for p in range(1, model.order):
Y[:, t, ep] -= -model.parameters[:, :, p].dot(Y[:, t-p, ep])

if noise is not None:
scale = Y.std()
Y += np.random.randn(*Y.shape) * (scale * noise)

return Y


def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None, flat_channels=None):
def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None, flat_channels=None, noise=None):
"""Simulate raw MEG data from a 306-channel MEGIN template.
Parameters
Expand Down Expand Up @@ -90,7 +94,7 @@ def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None
fname = 'reduced_mvar_pcacomp_{0}.npy'.format(mod)
pcacomp = np.load(os.path.join(basedir, fname))

Xsim = simulate_data(red_model, num_samples=sim_samples) * 2e-12
Xsim = simulate_data(red_model, num_samples=sim_samples, noise=noise) * 2e-12
Xsim = pcacomp.T.dot(Xsim[:,:,0])[:,:,None] # back to full space

Y[mne.pick_types(info, meg=mod), :] = Xsim[:, :, 0]
Expand Down
Loading

0 comments on commit 149f455

Please sign in to comment.