diff --git a/osl/glm/glm_base.py b/osl/glm/glm_base.py index b3040511..d4186c11 100644 --- a/osl/glm/glm_base.py +++ b/osl/glm/glm_base.py @@ -5,6 +5,8 @@ import glmtools as glm import mne import numpy as np +from scipy.sparse import csr_array +from ..source_recon.parcellation import spatial_dist_adjacency, guess_parcellation class GLMBaseResult: @@ -105,7 +107,7 @@ def __init__(self, model, design, info, config, fl_contrast_names=None, data=Non else: self.fl_contrast_names = fl_contrast_names - def get_channel_adjacency(self): + def get_channel_adjacency(self, dist=np.inf): """Return adjacency matrix of channels. Parameters @@ -119,8 +121,15 @@ def get_channel_adjacency(self): ch_names : list of str The channel names. """ - ch_type = mne.io.meas_info._get_channel_types(self.info)[0] # Assuming these are all the same! - adjacency, ch_names = mne.channels.channels._compute_ch_adjacency(self.info, ch_type) + if np.any(['parcel' in ch for ch in self.info['ch_names']]): + # We have parcellated data + parcellation_file = guess_parcellation(int(np.sum(['parcel' in ch for ch in self.info['ch_names']]))) + adjacency = csr_array(spatial_dist_adjacency(parcellation_file, dist=dist)) + elif np.any(['state' in ch for ch in self.info['ch_names']]) or np.any(['mode' in ch for ch in self.info['ch_names']]): + adjacency = csr_array(np.eye(len(self.info['ch_names']))) + else: + ch_type = mne.io.meas_info._get_channel_types(self.info)[0] # Assuming these are all the same! + adjacency, ch_names = mne.channels.channels._compute_ch_adjacency(self.info, ch_type) ntests = np.prod(self.data.data.shape[2:]) ntimes = self.data.data.shape[3] print('{} : {}'.format(ntimes, ntests)) diff --git a/osl/glm/glm_spectrum.py b/osl/glm/glm_spectrum.py index 1187ffa2..9fb0c83c 100644 --- a/osl/glm/glm_spectrum.py +++ b/osl/glm/glm_spectrum.py @@ -10,8 +10,14 @@ from sails.stft import glm_periodogram from scipy import signal, stats from .glm_base import GLMBaseResult, GroupGLMBaseResult, SensorClusterPerm, SensorMaxStatPerm - +from ..source_recon.parcellation import guess_parcellation, find_file, parcel_centers from matplotlib.patches import ConnectionPatch +from matplotlib.colors import ListedColormap + + # TODO: should replace with osl functions or make a soft import here +from osl_dynamics.analysis import power +import nibabel as nib +from nilearn.plotting import plot_glass_brain, plot_markers #%% --------------------------------------- # @@ -66,7 +72,6 @@ def plot_joint_spectrum(self, contrast=0, freqs='auto', base=1, ax=None, Proportion of plot dedicted to topomaps(Default value = 1/3) metric : {'copes' or 'tstats} Which metric to plot? (Default value = 'copes') - """ if metric == 'copes': spec = self.model.copes[contrast, :, :].T @@ -568,7 +573,6 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut Number of xaxis ticks to skip, useful for tight plots (Default value = 1) topo_prop : float Proportion of plot dedicted to topomaps(Default value = 1/3) - """ if ax is None: fig = plt.figure() @@ -580,7 +584,7 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut title_prop = 0.1 main_prop = 1-title_prop-topo_prop main_ax = ax.inset_axes((0, 0, 1, main_prop)) - + plot_sensor_spectrum(xvect, psd, info, ax=main_ax, base=base, lw=0.25, ylabel=ylabel) fx = prep_scaled_freq(base, xvect) @@ -675,7 +679,10 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut # Plot topo dat = psd[fmid, :] - im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False, mask=channels, ch_type='planar1') + if np.any(['parcel' in ch for ch in info['ch_names']]): # source level data + im = plot_source_topo(dat, axis=topo_ax) + else: + im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False, mask=channels, ch_type='planar1') topos.append(im) if topo_scale == 'joint' and len(topos) > 0: @@ -694,6 +701,66 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut ax.set_title(title, x=0.5, y=1-title_prop) +def plot_source_topo(data_map, parcellation_file=None, mask_file='MNI152_T1_8mm_brain.nii.gz', axis=None, cmap=None, vmin=None, vmax=None, alpha=0.7): + """Plot a data map on a cortical surface. Wrapper for nilearn.plotting.plot_glass_brain. + + Parameters + ---------- + data_map : array_like + Vector of data values to plot (nparc,) + parcellation_file : str + Filepath of parcellation file to plot data on + mask_file : str + Filepath of mask file to plot data on (Default value = 'MNI152_T1_8mm_brain.nii.gz') + axis : {None or axis handle} + Axis to plot into (Default value = None) + cmap : {None or matplotlib colormap} + Colormap to use for plotting (Default value = None) + vmin : {None or float} + Minimum value for colormap (Default value = None) + vmax : {None or float} + Maximum value for colormap (Default value = None) + alpha : {None or float} + Alpha value for colormap (Default value = None) + + Returns + ------- + image : :py:class:`matplotlib.image.AxesImage ` + AxesImage object + """ + + if parcellation_file is None: + parcellation_file = guess_parcellation(data_map) + parcellation_file = find_file(parcellation_file) + mask_file = find_file(mask_file) + + # prepare figure + if axis is None: + fig, axis = plt.subplots() + if cmap is None: + cmap = plt.cm.RdBu_r + + # prepare data + data_map = power.parcel_vector_to_voxel_grid(mask_file, parcellation_file, data_map) + mask = nib.load(mask_file) + nii = nib.Nifti1Image(data_map, mask.affine, mask.header) + + plot_glass_brain( + nii, + output_file=None, + display_mode='z', + colorbar=False, + axes=axis, + cmap=cmap, + alpha=alpha, + vmin=vmin, + vmax=vmax, + plot_abs=False, + annotate=False, + ) + return plt.gca().get_images()[0] + + def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1, topo_scale='joint', lw=0.5, ylabel='Power', title='', ylim=None, xtick_skip=1, topo_prop=1/5): @@ -728,7 +795,16 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1, Number of xaxis ticks to skip, useful for tight plots (Default value = 1) topo_prop : float Proportion of plot dedicted to topomaps(Default value = 1/3) - + source : bool + Whether the data is in source level (Default value = False) + parcellation_file : str + Filepath of parcellation file to plot data on (Default value = None) + + Notes + ----- + This function assumes the data are in MNE-Python format unless parcellation_file is specified. + If parcellation_file is not specified, and MNE-Python doesn't recognise the channel format, + the parcellation is guessed from the number of channels. """ if ax is None: fig = plt.figure() @@ -745,10 +821,23 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1, fx = prep_scaled_freq(base, xvect) if freqs == 'auto': - topo_freq_inds = signal.find_peaks(np.abs(psd.mean(axis=1)), distance=xvect.shape[0]/3)[0] - if len(topo_freq_inds) > 2: - I = np.argsort(np.abs(psd.mean(axis=1))[topo_freq_inds])[-2:] - topo_freq_inds = topo_freq_inds[I] + if base == 1: + topo_freq_inds = signal.find_peaks(np.abs(psd.mean(axis=1)), distance=xvect.shape[0]/3)[0] + if len(topo_freq_inds) > 2: + I = np.argsort(np.abs(psd.mean(axis=1))[topo_freq_inds])[-2:] + topo_freq_inds = topo_freq_inds[I] + elif base == 0.5: # using distance is a bit tricky with sqrt freqs + dist = xvect.shape[0]/2.5 + tmp_topo_freq_inds = signal.find_peaks(np.abs(psd.mean(axis=1)))[0] + topo_freq_inds = [] + for i, ifrq in enumerate(tmp_topo_freq_inds): + if i==0: + topo_freq_inds.append(ifrq) + elif len(topo_freq_inds)==3: + continue + elif (np.argmin(np.abs(np.sqrt(ifrq)-fx[0])) - np.argmin(np.abs(np.sqrt(tmp_topo_freq_inds[i-1])-fx[0]))) < dist: + topo_freq_inds.append(ifrq) + topo_freq_inds = np.array(topo_freq_inds) freqs = xvect[topo_freq_inds] else: topo_freq_inds = [np.argmin(np.abs(xvect - ff)) for ff in freqs] @@ -785,7 +874,10 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1, ax.figure.add_artist(con) dat = psd[topo_freq_inds[idx], :] - im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False) + if np.any(['parcel' in ch for ch in info['ch_names']]): # source data + im = plot_source_topo(dat, axis=topo_ax) + else: + im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False) topos.append(im) if topo_scale == 'joint': @@ -808,7 +900,8 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1, def plot_sensor_spectrum(xvect, psd, info, ax=None, sensor_proj=False, xticks=None, xticklabels=None, lw=0.5, title=None, - sensor_cols=True, base=1, ylabel=None, xtick_skip=1): + sensor_cols=True, base=1, ylabel=None, xtick_skip=1, + source=False, parcellation_file=None): """Plot a GLM-Spectrum contrast with spatial line colouring. Parameters @@ -854,7 +947,15 @@ def plot_sensor_spectrum(xvect, psd, info, ax=None, sensor_proj=False, if sensor_proj: axins = ax.inset_axes([0.6, 0.6, 0.37, 0.37]) - plot_channel_layout(axins, info) + if np.any(['parcel' in ch for ch in info['ch_names']]): + parcellation_file = guess_parcellation(psd.T) + colors, order = get_source_cols(parcellation_file, return_order=True) + cmap = ListedColormap(colors) + parc_centers = parcel_centers(parcellation_file) + n_parcels = parc_centers.shape[0] + plot_markers(np.arange(n_parcels), parc_centers, axes=axins, node_size=10, node_cmap=cmap) + else: + plot_channel_layout(axins, info) if title is not None: ax.set_title(title) @@ -874,7 +975,13 @@ def plot_sensor_data(xvect, data, info, ax=None, lw=0.5, fx, xticklabels, xticks = prep_scaled_freq(base, xvect) if sensor_cols: - colors, pos, outlines = get_mne_sensor_cols(info) + if np.any(['parcel' in ch for ch in info['ch_names']]): + parcellation_file = guess_parcellation(data.T) + colors = get_source_cols(parcellation_file) + elif np.any(['state' in ch for ch in info['ch_names']]) or np.any(['mode' in ch for ch in info['ch_names']]): + colors = None + else: + colors, pos, outlines = get_mne_sensor_cols(info) else: colors = None @@ -889,8 +996,6 @@ def plot_sensor_data(xvect, data, info, ax=None, lw=0.5, def prep_scaled_freq(base, freq_vect): """ Prepare frequency vector for plotting with a given scaling. - - Parameters ---------- @@ -925,6 +1030,27 @@ def prep_scaled_freq(base, freq_vect): return fx, ftick, ftickscaled +def get_source_cols(parcellation_file, return_order=False): + parc_centers = stats.zscore(parcel_centers(parcellation_file), axis=0) + # parc_centers = parcel_centers(parcellation_file) + x = parc_centers[:, 0] + y = parc_centers[:, 1] + z = parc_centers[:, 2] + # Re-order to use colour to indicate anterior->posterior location + # ref = np.argsort(y)[-1] + # dist = np.sqrt((x[ref]-x)**2 + (y[ref]-y)**2 + (z[ref]-z)**2) + ref = [-5, -5, -3] + # dist = np.sqrt((ref[0]-x)**2 + (ref[1]-y)**2 + (ref[2]-z)**2) + # order = np.argsort(dist) + colors = mne.viz.evoked._rgb(x, y, z) + order = [np.argsort(np.sqrt(ref[i]-parc_centers[:,i])**2) for i in range(3)] + colors = np.vstack([colors[order[0],0], colors[order[1],1], colors[order[2],2]]).T + if return_order: + return colors, order + else: + return colors + + def get_mne_sensor_cols(info): """ Get sensor colours from MNE info object. @@ -966,7 +1092,7 @@ def plot_channel_layout(ax, info, size=30, marker='o'): info : :py:class:`mne.Info ` MNE-Python info object size : int - Size of sensor markers (Default value = 30) + Size of sensor ยง (Default value = 30) marker : str Marker type (Default value = 'o') diff --git a/osl/source_recon/parcellation/parcellation.py b/osl/source_recon/parcellation/parcellation.py index d3f1bdc5..347b4f72 100644 --- a/osl/source_recon/parcellation/parcellation.py +++ b/osl/source_recon/parcellation/parcellation.py @@ -63,6 +63,46 @@ def find_file(filename): return filename +def guess_parcellation(data, return_path=False): + """Guess parcellation file from data. + + Parameters + ---------- + data : vector or matrix + Data to guess parcellation from. first dimension is assumed to be parcels. + return_path : bool + If True, return path to parcellation file, otherwise return filename. + + returns + ------- + filename : str + Path to parcellation file. + """ + if type(data) is int: + nparc = data + else: + nparc = data.shape[0] + + # print('Guessing parcellation from data with {} parcels'.format(nparc)) + if nparc==52: + fname = "Glasser52_binary_space-MNI152NLin6_res-8x8x8.nii.gz" + elif nparc==50: + fname = "Glasser50_space-MNI152NLin6_res-8x8x8.nii.gz" + elif nparc==38: + fname = "fMRI_parcellation_ds8mm.nii.gz" + elif nparc==39: + fname = "fmri_d100_parcellation_with_PCC_tighterMay15_v2_8mm.nii.gz" + elif nparc==78: + fname = "aal_cortical_merged_8mm_stacked.nii.gz" + else: + raise ValueError("Can't guess parcellation for {} channels".format(nparc)) + # print('Guessing parcellation is {}'.format(fname)) + if return_path: + return find_file(fname) + else: + return find_file(fname).split('/')[-1] + + def parcellate_timeseries(parcellation_file, voxel_timeseries, voxel_coords, method, working_dir): """Parcellate a voxel time series.