Skip to content

Commit

Permalink
Updates to glmspectrum and plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
AJQuinn committed Mar 2, 2024
1 parent 6df02ff commit 6b3501d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 8 deletions.
59 changes: 59 additions & 0 deletions examples/spectrum_analysis_walkthrough.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import osl
from scipy import signal
import matplotlib.pyplot as plt

raw = osl.utils.simulate_raw_from_template(10000, noise=1/3)
raw.pick(picks='mag')


#%%
spec = osl.glm.glm_spectrum(raw)
spec.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5, title='testing123')

#%%
aper, osc = osl.glm.glm_irasa(raw, mode='magnitude')
plt.figure()
ax = plt.subplot(121)
aper.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(122)
osc.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5,ax=ax)


#%%
alpha = raw.copy().filter(l_freq=7, h_freq=13)
covs = {'alpha': np.abs(signal.hilbert(alpha.get_data()[raw.ch_names.index('MEG1711'), :]))}

spec = osl.glm.glm_spectrum(raw, reg_ztrans=covs)

plt.figure()
ax = plt.subplot(121)
spec.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(122)
spec.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax)




aper, osc = osl.glm.glm_irasa(raw, reg_ztrans=covs)

plt.figure()
ax = plt.subplot(221)
aper.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(222)
aper.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(223)
osc.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax)
ax = plt.subplot(224)
osc.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax)




gglmsp = osl.glm.read_glm_spectrum('/Users/andrew/Downloads/bigmeg-camcan-movecomptrans_glm-spectrum_grad-noztrans_group-level.pkl')
spec = osl.glm.GroupSensorGLMSpectrum(gglmsp.model,
gglmsp.design,
gglmsp.config,
gglmsp.info,
fl_contrast_names=None,
data=gglmsp.data)
P = osl.glm.MaxStatPermuteGLMSpectrum(spec, 1, nperms=25)
9 changes: 5 additions & 4 deletions osl/glm/glm_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def glm_spectrum(XX, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
contrasts=None, fit_intercept=True, standardise_data=False,
window_type='hann', nperseg=None, noverlap=None, nfft=None,
detrend='constant', return_onesided=True, scaling='density',
mode='psd', fmin=None, fmax=None, axis=-1, fs=1):
mode='psd', fmin=None, fmax=None, axis=-1, fs=1, verbose='WARNING'):
"""Compute a GLM-Spectrum from a MNE-Python Raw data object.
Parameters
Expand Down Expand Up @@ -520,7 +520,8 @@ def glm_spectrum(XX, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
scaling=scaling,
mode=mode,
fmin=fmin,
fmax=fmax)
fmax=fmax,
verbose=verbose)

if isinstance(XX, mne.io.base.BaseRaw):
return SensorGLMSpectrum(glmsp, XX.info)
Expand All @@ -533,7 +534,7 @@ def glm_irasa(XX, method='modified', resample_factors=None, aperiodic_average='m
contrasts=None, fit_intercept=True, standardise_data=False,
window_type='hann', nperseg=None, noverlap=None, nfft=None,
detrend='constant', return_onesided=True, scaling='density',
mode='psd', fmin=None, fmax=None, axis=-1, fs=1, verbose='INFO'):
mode='psd', fmin=None, fmax=None, axis=-1, fs=1, verbose='WARNING'):
"""Compute a GLM-IRASA from a MNE-Python Raw data object.
Parameters
Expand Down Expand Up @@ -738,7 +739,7 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut
plot_sensor_spectrum(xvect, psd, info, ax=main_ax, base=base, lw=0.25, ylabel=ylabel)
fx = prep_scaled_freq(base, xvect)

yl = main_ax.get_ylim() if yl is None else yl
yl = main_ax.get_ylim() if ylim is None else ylim
yfactor = 1.2 if yl[1] > 0 else 0.8
main_ax.set_ylim(yl[0], yfactor*yl[1])

Expand Down
12 changes: 8 additions & 4 deletions osl/utils/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np


def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True):
def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True, noise=None):
"""Simulate data from a linear model.
Parameters
Expand All @@ -31,7 +31,6 @@ def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True):
"""


num_sources = model.nsignals

# Preallocate output
Expand All @@ -50,10 +49,15 @@ def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True):
for t in range(model.order, num_samples):
for p in range(1, model.order):
Y[:, t, ep] -= -model.parameters[:, :, p].dot(Y[:, t-p, ep])

if noise is not None:
scale = Y.std()
Y += np.random.randn(*Y.shape) * (scale * noise)

return Y


def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None, flat_channels=None):
def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None, flat_channels=None, noise=None):
"""Simulate raw MEG data from a 306-channel MEGIN template.
Parameters
Expand Down Expand Up @@ -90,7 +94,7 @@ def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None
fname = 'reduced_mvar_pcacomp_{0}.npy'.format(mod)
pcacomp = np.load(os.path.join(basedir, fname))

Xsim = simulate_data(red_model, num_samples=sim_samples) * 2e-12
Xsim = simulate_data(red_model, num_samples=sim_samples, noise=noise) * 2e-12
Xsim = pcacomp.T.dot(Xsim[:,:,0])[:,:,None] # back to full space

Y[mne.pick_types(info, meg=mod), :] = Xsim[:, :, 0]
Expand Down

0 comments on commit 6b3501d

Please sign in to comment.