Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/glm_module_updates' into plot_so…
Browse files Browse the repository at this point in the history
…urce
  • Loading branch information
cgohil8 committed Feb 26, 2024
2 parents b97b38f + 603b8f3 commit 8902efe
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 22 deletions.
41 changes: 23 additions & 18 deletions osl/glm/glm_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
from copy import deepcopy
from pathlib import Path
from itertools import compress

import mne
import numpy as np
Expand Down Expand Up @@ -37,9 +38,7 @@ def __init__(self, glmsp, info):
self.config = glmsp.config
super().__init__(glmsp.model, glmsp.design, info, data=glmsp.data)

def plot_joint_spectrum(self, contrast=0, freqs='auto', base=1, ax=None,
topo_scale='joint', lw=0.5, ylabel=None, title=None,
ylim=None, xtick_skip=1, topo_prop=1/5, metric='copes'):
def plot_joint_spectrum(self, contrast=0, metric='copes', **kwargs):
"""Plot a GLM-Spectrum contrast with spatial line colouring and topograpies.
Parameters
Expand Down Expand Up @@ -72,22 +71,20 @@ def plot_joint_spectrum(self, contrast=0, freqs='auto', base=1, ax=None,
"""
if metric == 'copes':
spec = self.model.copes[contrast, :, :].T
ylabel = 'Power' if ylabel is None else ylabel
kwargs['ylabel'] = 'Power' if kwargs.get('ylabel') is None else kwargs.get('ylabel')
elif metric == 'varcopes':
spec = self.model.varcopes[contrast, :, :].T
ylabel = 'Standard-Error' if ylabel is None else ylabel
kwargs['ylabel'] = 'Varcopes' if kwargs.get('ylabel') is None else kwargs.get('ylabel')
elif metric == 'tstats':
spec = self.model.tstats[contrast, :, :].T
ylabel = 't-statistics' if ylabel is None else ylabel
kwargs['ylabel'] = 't-statistics' if kwargs.get('ylabel') is None else kwargs.get('ylabel')
else:
raise ValueError("Metric '{}' not recognised".format(metric))

if title is None:
title = 'C {} : {}'.format(contrast, self.design.contrast_names[contrast])
if kwargs.get('title') is None:
kwargs['title'] = 'C {} : {}'.format(contrast, self.design.contrast_names[contrast])

plot_joint_spectrum(self.f, spec, self.info, freqs=freqs, base=base,
topo_scale=topo_scale, lw=lw, ylabel=ylabel, title=title,
ylim=ylim, xtick_skip=xtick_skip, topo_prop=topo_prop, ax=ax)
plot_joint_spectrum(self.f, spec, self.info, **kwargs)

def plot_sensor_spectrum(self, contrast, sensor_proj=False,
xticks=None, xticklabels=None, lw=0.5, ax=None, title=None,
Expand Down Expand Up @@ -254,8 +251,8 @@ def plot_joint_spectrum(self, gcontrast=0, fcontrast=0, freqs='auto', base=1, ax
raise ValueError("Metric '{}' not recognised".format(metric))

if title is None:
gtitle = 'gC {} : {}'.format(gcontrast, self.contrast_names[gcontrast])
ftitle = 'flC {} : {}'.format(fcontrast, self.fl_contrast_names[fcontrast])
gtitle = 'group con : {}'.format(self.contrast_names[gcontrast])
ftitle = 'first-level con : {}'.format(self.fl_contrast_names[fcontrast])

title = gtitle + '\n' + ftitle

Expand Down Expand Up @@ -286,7 +283,7 @@ class MaxStatPermuteGLMSpectrum(SensorMaxStatPerm):
"""A class holding the result for sensor x frequency cluster stats computed
from a group level GLM-Spectrum"""

def plot_sig_clusters(self, thresh, ax=None, base=1):
def plot_sig_clusters(self, thresh, ax=None, base=1, min_extent=1):
"""Plot the significant clusters at a given threshold.
Parameters
Expand All @@ -302,6 +299,11 @@ def plot_sig_clusters(self, thresh, ax=None, base=1):
title = title.format(self.gl_contrast_name, self.fl_contrast_name)

clu, obs = self.get_sig_clusters(thresh)
to_plot = []
for c in clu:
to_plot.append(False if len(c[2][0]) < min_extent or len(c[2][1]) < min_extent else True)

clu = list(compress(clu, to_plot))
plot_joint_spectrum_clusters(self.f, obs, clu, self.info, base=base, ax=ax, title=title, ylabel='t-stat')


Expand Down Expand Up @@ -494,7 +496,6 @@ def glm_spectrum(XX, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
mode=mode,
fmin=fmin,
fmax=fmax,
ret_class=True,
fit_method='glmtools')

if isinstance(XX, mne.io.base.BaseRaw):
Expand Down Expand Up @@ -797,7 +798,7 @@ def plot_source_topo(

def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1,
topo_scale='joint', lw=0.5, ylabel='Power', title='', ylim=None,
xtick_skip=1, topo_prop=1/5, topo_cmap=None):
xtick_skip=1, topo_prop=1/5, topo_cmap=None, topomap_args=None):
"""Plot a GLM-Spectrum contrast with spatial line colouring and topograpies.
Parameters
Expand Down Expand Up @@ -839,6 +840,8 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1,
fig.subplots_adjust(top=0.8)
ax = plt.subplot(111)

topomap_args = {} if topomap_args is None else topomap_args

ax.set_axis_off()

title_prop = 0.1
Expand Down Expand Up @@ -870,8 +873,10 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1,
else:
topo_freq_inds = [np.argmin(np.abs(xvect - ff)) for ff in freqs]

yl = main_ax.get_ylim()
main_ax.set_ylim(yl[0], 1.2*yl[1])
yl = main_ax.get_ylim() if ylim is None else ylim
yfactor = 1.2 if yl[1] > 0 else 0.8
main_ax.set_ylim(yl[0], yfactor*yl[1])
#yl = main_ax.get_ylim()

yt = ax.get_yticks()
inds = yt < yl[1]
Expand Down
12 changes: 8 additions & 4 deletions osl/preprocessing/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,22 +382,26 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc_raw', overwrite=False
)
dataset["raw"].save(fif_outname, overwrite=overwrite)

if dataset["events"] is not None:
if "events" in dataset and dataset['events'] is not None:
outname = outbase.format(run_id=run_id, ftype="events", fext="npy")
np.save(outname, dataset["events"])

if dataset["event_id"] is not None:
if "event_id" in dataset and dataset['event_id'] is not None:
outname = outbase.format(run_id=run_id, ftype="event-id", fext="yml")
yaml.dump(dataset["event_id"], open(outname, "w"))

if dataset["epochs"] is not None:
if "epochs" in dataset and dataset['epochs'] is not None:
outname = outbase.format(run_id=run_id, ftype="epo", fext="fif")
dataset["epochs"].save(outname, overwrite=overwrite)

if dataset["ica"] is not None:
if "ica" in dataset and dataset['ica'] is not None:
outname = outbase.format(run_id=run_id, ftype="ica", fext="fif")
dataset["ica"].save(outname, overwrite=overwrite)

if "tfr" in dataset and dataset['tfr'] is not None:
outname = outbase.format(run_id=run_id, ftype="tfr", fext="fif")
dataset["tfr"].save(outname, overwrite=overwrite)

return fif_outname

def read_dataset(fif, preload=False, ftype=None):
Expand Down

0 comments on commit 8902efe

Please sign in to comment.