From b2ae3f03d4d7ebdfbb6fa79376d71099be0cf697 Mon Sep 17 00:00:00 2001 From: AJQuinn Date: Mon, 4 Mar 2024 15:04:59 +0000 Subject: [PATCH] GLM Spectrum Update (#274) --- examples/spectrum_analysis_walkthrough.py | 59 +++++ osl/glm/glm_spectrum.py | 293 ++++++++++++++++------ osl/preprocessing/batch.py | 12 +- osl/tests/test_glm.py | 41 +++ osl/utils/simulate.py | 12 +- 5 files changed, 339 insertions(+), 78 deletions(-) create mode 100644 examples/spectrum_analysis_walkthrough.py create mode 100644 osl/tests/test_glm.py diff --git a/examples/spectrum_analysis_walkthrough.py b/examples/spectrum_analysis_walkthrough.py new file mode 100644 index 00000000..12aa266f --- /dev/null +++ b/examples/spectrum_analysis_walkthrough.py @@ -0,0 +1,59 @@ +import osl +from scipy import signal +import matplotlib.pyplot as plt + +raw = osl.utils.simulate_raw_from_template(10000, noise=1/3) +raw.pick(picks='mag') + + +#%% +spec = osl.glm.glm_spectrum(raw) +spec.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5, title='testing123') + +#%% +aper, osc = osl.glm.glm_irasa(raw, mode='magnitude') +plt.figure() +ax = plt.subplot(121) +aper.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5,ax=ax) +ax = plt.subplot(122) +osc.plot_joint_spectrum(freqs=(1, 10, 17), base=0.5,ax=ax) + + +#%% +alpha = raw.copy().filter(l_freq=7, h_freq=13) +covs = {'alpha': np.abs(signal.hilbert(alpha.get_data()[raw.ch_names.index('MEG1711'), :]))} + +spec = osl.glm.glm_spectrum(raw, reg_ztrans=covs) + +plt.figure() +ax = plt.subplot(121) +spec.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax) +ax = plt.subplot(122) +spec.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax) + + + + +aper, osc = osl.glm.glm_irasa(raw, reg_ztrans=covs) + +plt.figure() +ax = plt.subplot(221) +aper.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax) +ax = plt.subplot(222) +aper.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax) +ax = plt.subplot(223) +osc.plot_joint_spectrum(0, freqs=(1, 10, 17), base=0.5,ax=ax) +ax = plt.subplot(224) +osc.plot_joint_spectrum(1, freqs=(1, 10, 17), base=0.5,ax=ax) + + + + +gglmsp = osl.glm.read_glm_spectrum('/Users/andrew/Downloads/bigmeg-camcan-movecomptrans_glm-spectrum_grad-noztrans_group-level.pkl') +spec = osl.glm.GroupSensorGLMSpectrum(gglmsp.model, + gglmsp.design, + gglmsp.config, + gglmsp.info, + fl_contrast_names=None, + data=gglmsp.data) +P = osl.glm.MaxStatPermuteGLMSpectrum(spec, 1, nperms=25) diff --git a/osl/glm/glm_spectrum.py b/osl/glm/glm_spectrum.py index 1187ffa2..8a5bcc3e 100644 --- a/osl/glm/glm_spectrum.py +++ b/osl/glm/glm_spectrum.py @@ -2,12 +2,15 @@ import pickle from copy import deepcopy from pathlib import Path +from itertools import compress import glmtools as glm import matplotlib.pyplot as plt import mne import numpy as np from sails.stft import glm_periodogram +from sails.stft import glm_irasa as sails_glm_irasa + from scipy import signal, stats from .glm_base import GLMBaseResult, GroupGLMBaseResult, SensorClusterPerm, SensorMaxStatPerm @@ -34,9 +37,7 @@ def __init__(self, glmsp, info): self.config = glmsp.config super().__init__(glmsp.model, glmsp.design, info, data=glmsp.data) - def plot_joint_spectrum(self, contrast=0, freqs='auto', base=1, ax=None, - topo_scale='joint', lw=0.5, ylabel=None, title=None, - ylim=None, xtick_skip=1, topo_prop=1/5, metric='copes'): + def plot_joint_spectrum(self, contrast=0, metric='copes', **kwargs): """Plot a GLM-Spectrum contrast with spatial line colouring and topograpies. Parameters @@ -70,26 +71,23 @@ def plot_joint_spectrum(self, contrast=0, freqs='auto', base=1, ax=None, """ if metric == 'copes': spec = self.model.copes[contrast, :, :].T - ylabel = 'Power' if ylabel is None else ylabel + kwargs['ylabel'] = 'Power' if kwargs.get('ylabel') is None else kwargs.get('ylabel') elif metric == 'varcopes': spec = self.model.varcopes[contrast, :, :].T - ylabel = 'Standard-Error' if ylabel is None else ylabel + kwargs['ylabel'] = 'Varcopes' if kwargs.get('ylabel') is None else kwargs.get('ylabel') elif metric == 'tstats': spec = self.model.tstats[contrast, :, :].T - ylabel = 't-statistics' if ylabel is None else ylabel + kwargs['ylabel'] = 't-statistics' if kwargs.get('ylabel') is None else kwargs.get('ylabel') else: raise ValueError("Metric '{}' not recognised".format(metric)) - if title is None: - title = 'C {} : {}'.format(contrast, self.design.contrast_names[contrast]) + if kwargs.get('title') is None: + kwargs['title'] = 'C {} : {}'.format(contrast, self.design.contrast_names[contrast]) + + plot_joint_spectrum(self.f, spec, self.info, **kwargs) - plot_joint_spectrum(self.f, spec, self.info, freqs=freqs, base=base, - topo_scale=topo_scale, lw=lw, ylabel=ylabel, title=title, - ylim=ylim, xtick_skip=xtick_skip, topo_prop=topo_prop, ax=ax) - def plot_sensor_spectrum(self, contrast, sensor_proj=False, - xticks=None, xticklabels=None, lw=0.5, ax=None, title=None, - sensor_cols=True, base=1, ylabel=None, xtick_skip=1, metric='copes'): + def plot_sensor_spectrum(self, contrast=0, metric='copes', **kwargs): """Plot a GLM-Spectrum contrast with spatial line colouring. Parameters @@ -123,15 +121,20 @@ def plot_sensor_spectrum(self, contrast, sensor_proj=False, """ if metric == 'copes': spec = self.model.copes[contrast, :, :].T + kwargs['ylabel'] = 'Power' if kwargs.get('ylabel') is None else kwargs.get('ylabel') + elif metric == 'varcopes': + spec = self.model.varcopes[contrast, :, :].T + kwargs['ylabel'] = 'Varcopes' if kwargs.get('ylabel') is None else kwargs.get('ylabel') elif metric == 'tstats': spec = self.model.tstats[contrast, :, :].T + kwargs['ylabel'] = 't-statistics' if kwargs.get('ylabel') is None else kwargs.get('ylabel') + else: + raise ValueError("Metric '{}' not recognised".format(metric)) - if title is None: - title = 'C {} : {}'.format(contrast, self.design.contrast_names[contrast]) + if kwargs.get('title') is None: + kwargs['title'] = 'C {} : {}'.format(contrast, self.design.contrast_names[contrast]) - plot_sensor_spectrum(self.f, spec, self.info, ax=ax, sensor_proj=sensor_proj, - xticks=xticks, xticklabels=xticklabels, lw=lw, title=title, - sensor_cols=sensor_cols, base=base, ylabel=ylabel, xtick_skip=xtick_skip) + plot_sensor_spectrum(self.f, spec, self.info, **kwargs) class GroupSensorGLMSpectrum(GroupGLMBaseResult): @@ -207,9 +210,7 @@ def save_pkl(self, outname, overwrite=True, save_data=False): if save_data == False: self.data = dd - def plot_joint_spectrum(self, gcontrast=0, fcontrast=0, freqs='auto', base=1, ax=None, - topo_scale='joint', lw=0.5, ylabel='Power', title=None, - ylim=None, xtick_skip=1, topo_prop=1/5, metric='copes'): + def plot_joint_spectrum(self, gcontrast=0, fcontrast=0, metric='copes', **kwargs): """ Plot a GLM-Spectrum contrast with spatial line colouring and topograpies. Parameters @@ -245,23 +246,24 @@ def plot_joint_spectrum(self, gcontrast=0, fcontrast=0, freqs='auto', base=1, ax """ if metric == 'copes': spec = self.model.copes[gcontrast, fcontrast, :, :].T + kwargs['ylabel'] = 'Power' if kwargs.get('ylabel') is None else kwargs.get('ylabel') elif metric == 'varcopes': spec = self.model.varcopes[gcontrast, fcontrast, :, :].T - ylabel = 'Standard-Error' if ylabel is None else ylabel + kwargs['ylabel'] = 'Varcopes' if kwargs.get('ylabel') is None else kwargs.get('ylabel') elif metric == 'tstats': spec = self.model.tstats[gcontrast, fcontrast, :, :].T + kwargs['ylabel'] = 't-statistics' if kwargs.get('ylabel') is None else kwargs.get('ylabel') else: raise ValueError("Metric '{}' not recognised".format(metric)) - if title is None: - gtitle = 'gC {} : {}'.format(gcontrast, self.contrast_names[gcontrast]) - ftitle = 'flC {} : {}'.format(fcontrast, self.fl_contrast_names[fcontrast]) + if kwargs.get('title') is None: + gtitle = 'group con : {}'.format(self.contrast_names[gcontrast]) + ftitle = 'first-level con : {}'.format(self.fl_contrast_names[fcontrast]) + + kwargs['title'] = gtitle + '\n' + ftitle - title = gtitle + '\n' + ftitle + plot_joint_spectrum(self.f, spec, self.info, **kwargs) - plot_joint_spectrum(self.f, spec, self.info, freqs=freqs, base=base, - topo_scale=topo_scale, lw=lw, ylabel=ylabel, title=title, - ylim=ylim, xtick_skip=xtick_skip, topo_prop=topo_prop, ax=ax) def get_fl_contrast(self, fl_con): """Get the data from a single first level contrast. @@ -287,7 +289,7 @@ class MaxStatPermuteGLMSpectrum(SensorMaxStatPerm): """A class holding the result for sensor x frequency cluster stats computed from a group level GLM-Spectrum""" - def plot_sig_clusters(self, thresh, ax=None, base=1): + def plot_sig_clusters(self, thresh, ax=None, base=1, min_extent=1): """Plot the significant clusters at a given threshold. Parameters @@ -303,6 +305,11 @@ def plot_sig_clusters(self, thresh, ax=None, base=1): title = title.format(self.gl_contrast_name, self.fl_contrast_name) clu, obs = self.get_sig_clusters(thresh) + to_plot = [] + for c in clu: + to_plot.append(False if len(c[2][0]) < min_extent or len(c[2][1]) < min_extent else True) + clu = list(compress(clu, to_plot)) + plot_joint_spectrum_clusters(self.f, obs, clu, self.info, base=base, ax=ax, title=title, ylabel='t-stat') @@ -310,7 +317,7 @@ class ClusterPermuteGLMSpectrum(SensorClusterPerm): """A class holding the result for sensor x frequency cluster stats computed from a group level GLM-Spectrum""" - def plot_sig_clusters(self, thresh, ax=None, base=1): + def plot_sig_clusters(self, thresh, ax=None, base=1, min_extent=1): """Plot the significant clusters at a given threshold. Parameters @@ -327,6 +334,12 @@ def plot_sig_clusters(self, thresh, ax=None, base=1): title = title.format(self.gl_contrast_name, self.fl_contrast_name) clu, obs = self.perms.get_sig_clusters(thresh, self.perm_data) + + to_plot = [] + for c in clu: + to_plot.append(False if len(c[2][0]) < min_extent or len(c[2][1]) < min_extent else True) + clu = list(compress(clu, to_plot)) + plot_joint_spectrum_clusters(self.f, obs, clu, self.info, base=base, ax=ax, title=title, ylabel='t-stat') @@ -399,65 +412,72 @@ def glm_spectrum(XX, reg_categorical=None, reg_ztrans=None, reg_unitmax=None, contrasts=None, fit_intercept=True, standardise_data=False, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', - mode='psd', fmin=None, fmax=None, axis=-1, fs=1): + mode='psd', fmin=None, fmax=None, axis=-1, fs=1, verbose='WARNING'): """Compute a GLM-Spectrum from a MNE-Python Raw data object. Parameters ---------- XX : {MNE Raw object, or data array} Data to compute GLM-Spectrum from + standardise_data : bool + Flag indicating whether to z-transform input data (Default value = False) reg_categorical : dict or None Dictionary of covariate time series to be added as binary regessors. (Default value = None) reg_ztrans : dict or None Dictionary of covariate time series to be added as z-standardised regessors. (Default value = None) reg_unitmax : dict or None Dictionary of confound time series to be added as positive-valued unitmax regessors. (Default value = None) - contrasts : - (Default value = None) + contrasts : dict or None + Dictionary of contrasts to be computed in the model. + (Default value = None, will add a simple contrast for each regressor) fit_intercept : bool - Specifies whether a constant valued 'intercept' regressor is included in the model. (Default value = True) - standardise_data : bool - Flag indicating whether to z-transform input data (Default value = False) - window_type : - (Default value = 'hann') - nperseg : int, optional + Specifies whether a constant valued 'intercept' regressor is included in the model. (Default value = True)' + + nperseg : int Length of each segment. Defaults to None, but if window is str or tuple, is set to 256, and if window is array_like, is set to the length of the window. - noverlap : int, optional - Number of points to overlap between segments. If `None`, - ``noverlap = nperseg // 2``. Defaults to `None`. - nfft : int, optional - Length of the FFT used, if a zero padded FFT is desired. If - `None`, the FFT length is `nperseg`. Defaults to `None`. + noverlap : int + Number of samples that successive sliding windows should overlap. + window_type : str or tuple or array_like, optional + Desired window to use. If `window` is a string or tuple, it is + passed to `scipy.signal.windows.get_window` to generate the + window values, which are DFT-even by default. See `scipy.signal.windows` + for a list of windows and required parameters. + If `window` is array_like it will be used directly as the window and its + length must be nperseg. Defaults to a Hann window. detrend : str or function or `False`, optional Specifies how to detrend each segment. If `detrend` is a string, it is passed as the `type` argument to the `detrend` function. If it is a function, it takes a segment and returns a detrended segment. If `detrend` is `False`, no detrending is - done. Defaults to 'constant'. + done. Defaults to 'constant'.' + + nfft : int + Length of the FFT to use (Default value = 256) + axis : int + Axis of input array along which the computation is performed. (Default value = -1) return_onesided : bool, optional If `True`, return a one-sided spectrum for real data. If `False` return a two-sided spectrum. Defaults to `True`, but for complex data, a two-sided spectrum is always returned. - scaling : { 'density', 'spectrum' }, optional + mode : {'psd', 'magnitude', 'angle', 'phase', 'complex'} + Which type of spectrum to return (Default value = 'psd') + scaling : { 'density', 'spectrum' } Selects between computing the power spectral density ('density') where `Pxx` has units of V**2/Hz and computing the power spectrum ('spectrum') where `Pxx` has units of V**2, if `x` is measured in V and `fs` is measured in Hz. Defaults to 'density' - mode : - (Default value = 'psd') - fmin : float or None, optional - Minimum frequency value to return (Default value = 0) - fmax : float or None, optional - Maximum frequency value to return (Default value = 0.5) - axis : int - Axis to compute spectrum over, overridden if input is an MNE raw - object (Default value = -1) - fs : float, optional - Sampling frequency of the `x` time series. Defaults to 1.0. Overridden - by value in XX.info['sfreq'] if input is a MNE Raw object. + fs : float + Sampling rate of the data + fmin : {float, None} + Smallest frequency in desired range (left hand boundary) + fmax : {float, None} + Largest frequency in desired range (right hand boundary)' + + verbose : {None, 'DEBUG', 'INFO', 'WARNING', 'CRITICAL'} + String indicating the level of detail to be printed to the screen during computation.' Returns ------- @@ -501,8 +521,7 @@ def glm_spectrum(XX, reg_categorical=None, reg_ztrans=None, reg_unitmax=None, mode=mode, fmin=fmin, fmax=fmax, - ret_class=True, - fit_method='glmtools') + verbose=verbose) if isinstance(XX, mne.io.base.BaseRaw): return SensorGLMSpectrum(glmsp, XX.info) @@ -510,6 +529,140 @@ def glm_spectrum(XX, reg_categorical=None, reg_ztrans=None, reg_unitmax=None, return glmsp +def glm_irasa(XX, method='modified', resample_factors=None, aperiodic_average='median', + reg_categorical=None, reg_ztrans=None, reg_unitmax=None, + contrasts=None, fit_intercept=True, standardise_data=False, + window_type='hann', nperseg=None, noverlap=None, nfft=None, + detrend='constant', return_onesided=True, scaling='density', + mode='psd', fmin=None, fmax=None, axis=-1, fs=1, verbose='WARNING'): + """Compute a GLM-IRASA from a MNE-Python Raw data object. + + Parameters + ---------- + XX : {MNE Raw object, or data array} + Data to compute GLM-Spectrum from + standardise_data : bool + Flag indicating whether to z-transform input data (Default value = False) + reg_categorical : dict or None + Dictionary of covariate time series to be added as binary regessors. (Default value = None) + reg_ztrans : dict or None + Dictionary of covariate time series to be added as z-standardised regessors. (Default value = None) + reg_unitmax : dict or None + Dictionary of confound time series to be added as positive-valued unitmax regessors. (Default value = None) + contrasts : dict or None + Dictionary of contrasts to be computed in the model. + (Default value = None, will add a simple contrast for each regressor) + fit_intercept : bool + Specifies whether a constant valued 'intercept' regressor is included in the model. (Default value = True)' + + method : {'original', 'modified'} + whether to compute the original implementation of IRASA or the modified update + (default is 'modified') + resample_factors : {None, array_like} + array of resampling factors to average across or None, in which a set + of factors are automatically computed (default is None). + aperiodic_average : {'mean', 'median', 'median_bias', 'min'} + method for averaging across irregularly resampled spectra to estimate + the aperiodic component (default is 'median').' + + nperseg : int + Length of each segment. Defaults to None, but if window is str or + tuple, is set to 256, and if window is array_like, is set to the + length of the window. + noverlap : int + Number of samples that successive sliding windows should overlap. + window_type : str or tuple or array_like, optional + Desired window to use. If `window` is a string or tuple, it is + passed to `scipy.signal.windows.get_window` to generate the + window values, which are DFT-even by default. See `scipy.signal.windows` + for a list of windows and required parameters. + If `window` is array_like it will be used directly as the window and its + length must be nperseg. Defaults to a Hann window. + detrend : str or function or `False`, optional + Specifies how to detrend each segment. If `detrend` is a + string, it is passed as the `type` argument to the `detrend` + function. If it is a function, it takes a segment and returns a + detrended segment. If `detrend` is `False`, no detrending is + done. Defaults to 'constant'.' + + nfft : int + Length of the FFT to use (Default value = 256) + axis : int + Axis of input array along which the computation is performed. (Default value = -1) + return_onesided : bool, optional + If `True`, return a one-sided spectrum for real data. If + `False` return a two-sided spectrum. Defaults to `True`, but for + complex data, a two-sided spectrum is always returned. + mode : {'psd', 'magnitude', 'angle', 'phase', 'complex'} + Which type of spectrum to return (Default value = 'psd') + scaling : { 'density', 'spectrum' } + Selects between computing the power spectral density ('density') + where `Pxx` has units of V**2/Hz and computing the power + spectrum ('spectrum') where `Pxx` has units of V**2, if `x` + is measured in V and `fs` is measured in Hz. Defaults to + 'density' + fs : float + Sampling rate of the data + fmin : {float, None} + Smallest frequency in desired range (left hand boundary) + fmax : {float, None} + Largest frequency in desired range (right hand boundary)' + + verbose : {None, 'DEBUG', 'INFO', 'WARNING', 'CRITICAL'} + String indicating the level of detail to be printed to the screen during computation.' + + Returns + ------- + :py:class:`SensorGLMSpectrum ` + SensorGLMSpectrum instance containing the fitted GLM-Spectrum. + References + ---------- + .. [1] Quinn, A. J., Atkinson, L., Gohil, C., Kohl, O., Pitt, J., Zich, C., Nobre, + A. C., & Woolrich, M. W. (2022). The GLM-Spectrum: A multilevel framework + for spectrum analysis with covariate and confound modelling. Cold Spring + Harbor Laboratory. https://doi.org/10.1101/2022.11.14.516449 + + """ + if isinstance(XX, mne.io.base.BaseRaw): + fs = XX.info['sfreq'] + nperseg = int(np.floor(fs)) if nperseg is None else nperseg + YY = XX.get_data() + axis = 1 + else: + YY = XX + + if standardise_data: + YY = stats.zscore(YY, axis=axis) + + # sails.sftf.config freqvals isn't right when frange is trimmed! + aper, osc = sails_glm_irasa(YY, axis=axis, + method=method, + resample_factors=resample_factors, + aperiodic_average=aperiodic_average, + reg_categorical=reg_categorical, + reg_ztrans=reg_ztrans, + reg_unitmax=reg_unitmax, + contrasts=contrasts, + fit_intercept=fit_intercept, + window_type=window_type, + fs=fs, + nperseg=nperseg, + noverlap=noverlap, + nfft=nfft, + detrend=detrend, + return_onesided=return_onesided, + scaling=scaling, + mode=mode, + fmin=fmin, + fmax=fmax, + verbose=verbose) + + if isinstance(XX, mne.io.base.BaseRaw): + return SensorGLMSpectrum(aper, XX.info), SensorGLMSpectrum(osc, XX.info) + else: + return aper, osc + + def read_glm_spectrum(infile): """Read in a GLMSpectrum object that has been saved as as a pickle. @@ -534,7 +687,7 @@ def read_glm_spectrum(infile): def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='auto', base=1, topo_scale='joint', lw=0.5, ylabel='Power', title='', ylim=None, - xtick_skip=1, topo_prop=1/5): + xtick_skip=1, topo_prop=1/5, topomap_args=None): """ Plot a GLM-Spectrum contrast from cluster objects, with spatial line colouring and topograpies. Parameters @@ -577,6 +730,8 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut ax.set_axis_off() + topomap_args = {} if topomap_args is None else topomap_args + title_prop = 0.1 main_prop = 1-title_prop-topo_prop main_ax = ax.inset_axes((0, 0, 1, main_prop)) @@ -584,8 +739,9 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut plot_sensor_spectrum(xvect, psd, info, ax=main_ax, base=base, lw=0.25, ylabel=ylabel) fx = prep_scaled_freq(base, xvect) - yl = main_ax.get_ylim() - main_ax.set_ylim(yl[0], 1.2*yl[1]) + yl = main_ax.get_ylim() if ylim is None else ylim + yfactor = 1.2 if yl[1] > 0 else 0.8 + main_ax.set_ylim(yl[0], yfactor*yl[1]) yt = ax.get_yticks() inds = yt < yl[1] @@ -675,7 +831,7 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut # Plot topo dat = psd[fmid, :] - im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False, mask=channels, ch_type='planar1') + im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False, mask=channels, ch_type='planar1', **topomap_args) topos.append(im) if topo_scale == 'joint' and len(topos) > 0: @@ -889,8 +1045,6 @@ def plot_sensor_data(xvect, data, info, ax=None, lw=0.5, def prep_scaled_freq(base, freq_vect): """ Prepare frequency vector for plotting with a given scaling. - - Parameters ---------- @@ -908,7 +1062,6 @@ def prep_scaled_freq(base, freq_vect): ftickscaled : array_like Scaled frequency ticks - Notes ----- Assuming ephy freq ranges for now - around 1-40Hz diff --git a/osl/preprocessing/batch.py b/osl/preprocessing/batch.py index 09ef93b9..8b50da76 100644 --- a/osl/preprocessing/batch.py +++ b/osl/preprocessing/batch.py @@ -415,22 +415,26 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc_raw', overwrite=False ) dataset["raw"].save(fif_outname, overwrite=overwrite) - if dataset["events"] is not None: + if "events" in dataset and dataset['events'] is not None: outname = outbase.format(run_id=run_id, ftype="events", fext="npy") np.save(outname, dataset["events"]) - if dataset["event_id"] is not None: + if "event_id" in dataset and dataset['event_id'] is not None: outname = outbase.format(run_id=run_id, ftype="event-id", fext="yml") yaml.dump(dataset["event_id"], open(outname, "w")) - if dataset["epochs"] is not None: + if "epochs" in dataset and dataset['epochs'] is not None: outname = outbase.format(run_id=run_id, ftype="epo", fext="fif") dataset["epochs"].save(outname, overwrite=overwrite) - if dataset["ica"] is not None: + if "ica" in dataset and dataset['ica'] is not None: outname = outbase.format(run_id=run_id, ftype="ica", fext="fif") dataset["ica"].save(outname, overwrite=overwrite) + if "tfr" in dataset and dataset['tfr'] is not None: + outname = outbase.format(run_id=run_id, ftype="tfr", fext="fif") + dataset["tfr"].save(outname, overwrite=overwrite) + return fif_outname def read_dataset(fif, preload=False, ftype=None): diff --git a/osl/tests/test_glm.py b/osl/tests/test_glm.py new file mode 100644 index 00000000..beeea079 --- /dev/null +++ b/osl/tests/test_glm.py @@ -0,0 +1,41 @@ +"""Tests for glm_spectrum and glm_epochs""" + +import unittest +import tempfile +import os + +import mne +import numpy as np + + +class TestGLMSpectrum(unittest.TestCase): + + @classmethod + def setUpClass(cls): + from ..utils import simulate_raw_from_template + + cls.flat_channels = None + cls.bad_channels = None + cls.bad_segments = None + + cls.raw = simulate_raw_from_template(500, + flat_channels=cls.flat_channels, + bad_channels=cls.bad_channels, + bad_segments=cls.bad_segments) + + cls.fpath = tempfile.NamedTemporaryFile().name + 'raw.fif' + cls.raw.save(cls.fpath) + + @classmethod + def tearDownClass(cls): + os.remove(cls.fpath) + + def test_glm_spectrum(self): + from ..glm import glm_spectrum + + spec = glm_spectrum(self.raw) + + def test_glm_irasa(self): + from ..glm import glm_irasa + + aper, osc = glm_irasa(self.raw) diff --git a/osl/utils/simulate.py b/osl/utils/simulate.py index 56f67728..71ed5b84 100644 --- a/osl/utils/simulate.py +++ b/osl/utils/simulate.py @@ -10,7 +10,7 @@ import numpy as np -def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True): +def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True, noise=None): """Simulate data from a linear model. Parameters @@ -31,7 +31,6 @@ def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True): """ - num_sources = model.nsignals # Preallocate output @@ -50,10 +49,15 @@ def simulate_data(model, num_samples=1000, num_realisations=1, use_cov=True): for t in range(model.order, num_samples): for p in range(1, model.order): Y[:, t, ep] -= -model.parameters[:, :, p].dot(Y[:, t-p, ep]) + + if noise is not None: + scale = Y.std() + Y += np.random.randn(*Y.shape) * (scale * noise) + return Y -def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None, flat_channels=None): +def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None, flat_channels=None, noise=None): """Simulate raw MEG data from a 306-channel MEGIN template. Parameters @@ -90,7 +94,7 @@ def simulate_raw_from_template(sim_samples, bad_segments=None, bad_channels=None fname = 'reduced_mvar_pcacomp_{0}.npy'.format(mod) pcacomp = np.load(os.path.join(basedir, fname)) - Xsim = simulate_data(red_model, num_samples=sim_samples) * 2e-12 + Xsim = simulate_data(red_model, num_samples=sim_samples, noise=noise) * 2e-12 Xsim = pcacomp.T.dot(Xsim[:,:,0])[:,:,None] # back to full space Y[mne.pick_types(info, meg=mod), :] = Xsim[:, :, 0]