Skip to content

Commit

Permalink
add in functionality to plot source/state level glmspectrum data.
Browse files Browse the repository at this point in the history
  • Loading branch information
matsvanes committed Jan 17, 2024
1 parent d02df75 commit a33f77b
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 20 deletions.
15 changes: 12 additions & 3 deletions osl/glm/glm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import glmtools as glm
import mne
import numpy as np
from scipy.sparse import csr_array
from ..source_recon.parcellation import spatial_dist_adjacency, guess_parcellation


class GLMBaseResult:
Expand Down Expand Up @@ -105,7 +107,7 @@ def __init__(self, model, design, info, config, fl_contrast_names=None, data=Non
else:
self.fl_contrast_names = fl_contrast_names

def get_channel_adjacency(self):
def get_channel_adjacency(self, dist=np.inf):
"""Return adjacency matrix of channels.
Parameters
Expand All @@ -119,8 +121,15 @@ def get_channel_adjacency(self):
ch_names : list of str
The channel names.
"""
ch_type = mne.io.meas_info._get_channel_types(self.info)[0] # Assuming these are all the same!
adjacency, ch_names = mne.channels.channels._compute_ch_adjacency(self.info, ch_type)
if np.any(['parcel' in ch for ch in self.info['ch_names']]):
# We have parcellated data
parcellation_file = guess_parcellation(int(np.sum(['parcel' in ch for ch in self.info['ch_names']])))
adjacency = csr_array(spatial_dist_adjacency(parcellation_file, dist=dist))
elif np.any(['state' in ch for ch in self.info['ch_names']]) or np.any(['mode' in ch for ch in self.info['ch_names']]):
adjacency = csr_array(np.eye(len(self.info['ch_names'])))
else:
ch_type = mne.io.meas_info._get_channel_types(self.info)[0] # Assuming these are all the same!
adjacency, ch_names = mne.channels.channels._compute_ch_adjacency(self.info, ch_type)
ntests = np.prod(self.data.data.shape[2:])
ntimes = self.data.data.shape[3]
print('{} : {}'.format(ntimes, ntests))
Expand Down
160 changes: 143 additions & 17 deletions osl/glm/glm_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@
from sails.stft import glm_periodogram
from scipy import signal, stats
from .glm_base import GLMBaseResult, GroupGLMBaseResult, SensorClusterPerm, SensorMaxStatPerm

from ..source_recon.parcellation import guess_parcellation, find_file, parcel_centers
from matplotlib.patches import ConnectionPatch
from matplotlib.colors import ListedColormap

# TODO: should replace with osl functions or make a soft import here
from osl_dynamics.analysis import power
import nibabel as nib
from nilearn.plotting import plot_glass_brain, plot_markers

#%% ---------------------------------------
#
Expand Down Expand Up @@ -66,7 +72,6 @@ def plot_joint_spectrum(self, contrast=0, freqs='auto', base=1, ax=None,
Proportion of plot dedicted to topomaps(Default value = 1/3)
metric : {'copes' or 'tstats}
Which metric to plot? (Default value = 'copes')
"""
if metric == 'copes':
spec = self.model.copes[contrast, :, :].T
Expand Down Expand Up @@ -568,7 +573,6 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut
Number of xaxis ticks to skip, useful for tight plots (Default value = 1)
topo_prop : float
Proportion of plot dedicted to topomaps(Default value = 1/3)
"""
if ax is None:
fig = plt.figure()
Expand All @@ -580,7 +584,7 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut
title_prop = 0.1
main_prop = 1-title_prop-topo_prop
main_ax = ax.inset_axes((0, 0, 1, main_prop))

plot_sensor_spectrum(xvect, psd, info, ax=main_ax, base=base, lw=0.25, ylabel=ylabel)
fx = prep_scaled_freq(base, xvect)

Expand Down Expand Up @@ -675,7 +679,10 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut

# Plot topo
dat = psd[fmid, :]
im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False, mask=channels, ch_type='planar1')
if np.any(['parcel' in ch for ch in info['ch_names']]): # source level data
im = plot_source_topo(dat, axis=topo_ax)
else:
im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False, mask=channels, ch_type='planar1')
topos.append(im)

if topo_scale == 'joint' and len(topos) > 0:
Expand All @@ -694,6 +701,66 @@ def plot_joint_spectrum_clusters(xvect, psd, clusters, info, ax=None, freqs='aut
ax.set_title(title, x=0.5, y=1-title_prop)


def plot_source_topo(data_map, parcellation_file=None, mask_file='MNI152_T1_8mm_brain.nii.gz', axis=None, cmap=None, vmin=None, vmax=None, alpha=0.7):
"""Plot a data map on a cortical surface. Wrapper for nilearn.plotting.plot_glass_brain.
Parameters
----------
data_map : array_like
Vector of data values to plot (nparc,)
parcellation_file : str
Filepath of parcellation file to plot data on
mask_file : str
Filepath of mask file to plot data on (Default value = 'MNI152_T1_8mm_brain.nii.gz')
axis : {None or axis handle}
Axis to plot into (Default value = None)
cmap : {None or matplotlib colormap}
Colormap to use for plotting (Default value = None)
vmin : {None or float}
Minimum value for colormap (Default value = None)
vmax : {None or float}
Maximum value for colormap (Default value = None)
alpha : {None or float}
Alpha value for colormap (Default value = None)
Returns
-------
image : :py:class:`matplotlib.image.AxesImage <matplotlib.image.AxesImage>`
AxesImage object
"""

if parcellation_file is None:
parcellation_file = guess_parcellation(data_map)
parcellation_file = find_file(parcellation_file)
mask_file = find_file(mask_file)

# prepare figure
if axis is None:
fig, axis = plt.subplots()
if cmap is None:
cmap = plt.cm.RdBu_r

# prepare data
data_map = power.parcel_vector_to_voxel_grid(mask_file, parcellation_file, data_map)
mask = nib.load(mask_file)
nii = nib.Nifti1Image(data_map, mask.affine, mask.header)

plot_glass_brain(
nii,
output_file=None,
display_mode='z',
colorbar=False,
axes=axis,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
plot_abs=False,
annotate=False,
)
return plt.gca().get_images()[0]


def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1,
topo_scale='joint', lw=0.5, ylabel='Power', title='', ylim=None,
xtick_skip=1, topo_prop=1/5):
Expand Down Expand Up @@ -728,7 +795,16 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1,
Number of xaxis ticks to skip, useful for tight plots (Default value = 1)
topo_prop : float
Proportion of plot dedicted to topomaps(Default value = 1/3)
source : bool
Whether the data is in source level (Default value = False)
parcellation_file : str
Filepath of parcellation file to plot data on (Default value = None)
Notes
-----
This function assumes the data are in MNE-Python format unless parcellation_file is specified.
If parcellation_file is not specified, and MNE-Python doesn't recognise the channel format,
the parcellation is guessed from the number of channels.
"""
if ax is None:
fig = plt.figure()
Expand All @@ -745,10 +821,23 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1,
fx = prep_scaled_freq(base, xvect)

if freqs == 'auto':
topo_freq_inds = signal.find_peaks(np.abs(psd.mean(axis=1)), distance=xvect.shape[0]/3)[0]
if len(topo_freq_inds) > 2:
I = np.argsort(np.abs(psd.mean(axis=1))[topo_freq_inds])[-2:]
topo_freq_inds = topo_freq_inds[I]
if base == 1:
topo_freq_inds = signal.find_peaks(np.abs(psd.mean(axis=1)), distance=xvect.shape[0]/3)[0]
if len(topo_freq_inds) > 2:
I = np.argsort(np.abs(psd.mean(axis=1))[topo_freq_inds])[-2:]
topo_freq_inds = topo_freq_inds[I]
elif base == 0.5: # using distance is a bit tricky with sqrt freqs
dist = xvect.shape[0]/2.5
tmp_topo_freq_inds = signal.find_peaks(np.abs(psd.mean(axis=1)))[0]
topo_freq_inds = []
for i, ifrq in enumerate(tmp_topo_freq_inds):
if i==0:
topo_freq_inds.append(ifrq)
elif len(topo_freq_inds)==3:
continue
elif (np.argmin(np.abs(np.sqrt(ifrq)-fx[0])) - np.argmin(np.abs(np.sqrt(tmp_topo_freq_inds[i-1])-fx[0]))) < dist:
topo_freq_inds.append(ifrq)
topo_freq_inds = np.array(topo_freq_inds)
freqs = xvect[topo_freq_inds]
else:
topo_freq_inds = [np.argmin(np.abs(xvect - ff)) for ff in freqs]
Expand Down Expand Up @@ -785,7 +874,10 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1,
ax.figure.add_artist(con)

dat = psd[topo_freq_inds[idx], :]
im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False)
if np.any(['parcel' in ch for ch in info['ch_names']]): # source data
im = plot_source_topo(dat, axis=topo_ax)
else:
im, cn = mne.viz.plot_topomap(dat, info, axes=topo_ax, show=False)
topos.append(im)

if topo_scale == 'joint':
Expand All @@ -808,7 +900,8 @@ def plot_joint_spectrum(xvect, psd, info, ax=None, freqs='auto', base=1,

def plot_sensor_spectrum(xvect, psd, info, ax=None, sensor_proj=False,
xticks=None, xticklabels=None, lw=0.5, title=None,
sensor_cols=True, base=1, ylabel=None, xtick_skip=1):
sensor_cols=True, base=1, ylabel=None, xtick_skip=1,
source=False, parcellation_file=None):
"""Plot a GLM-Spectrum contrast with spatial line colouring.
Parameters
Expand Down Expand Up @@ -854,7 +947,15 @@ def plot_sensor_spectrum(xvect, psd, info, ax=None, sensor_proj=False,

if sensor_proj:
axins = ax.inset_axes([0.6, 0.6, 0.37, 0.37])
plot_channel_layout(axins, info)
if np.any(['parcel' in ch for ch in info['ch_names']]):
parcellation_file = guess_parcellation(psd.T)
colors, order = get_source_cols(parcellation_file, return_order=True)
cmap = ListedColormap(colors)
parc_centers = parcel_centers(parcellation_file)
n_parcels = parc_centers.shape[0]
plot_markers(np.arange(n_parcels), parc_centers, axes=axins, node_size=10, node_cmap=cmap)
else:
plot_channel_layout(axins, info)

if title is not None:
ax.set_title(title)
Expand All @@ -874,7 +975,13 @@ def plot_sensor_data(xvect, data, info, ax=None, lw=0.5,
fx, xticklabels, xticks = prep_scaled_freq(base, xvect)

if sensor_cols:
colors, pos, outlines = get_mne_sensor_cols(info)
if np.any(['parcel' in ch for ch in info['ch_names']]):
parcellation_file = guess_parcellation(data.T)
colors = get_source_cols(parcellation_file)
elif np.any(['state' in ch for ch in info['ch_names']]) or np.any(['mode' in ch for ch in info['ch_names']]):
colors = None
else:
colors, pos, outlines = get_mne_sensor_cols(info)
else:
colors = None

Expand All @@ -889,8 +996,6 @@ def plot_sensor_data(xvect, data, info, ax=None, lw=0.5,

def prep_scaled_freq(base, freq_vect):
""" Prepare frequency vector for plotting with a given scaling.
Parameters
----------
Expand Down Expand Up @@ -925,6 +1030,27 @@ def prep_scaled_freq(base, freq_vect):
return fx, ftick, ftickscaled


def get_source_cols(parcellation_file, return_order=False):
parc_centers = stats.zscore(parcel_centers(parcellation_file), axis=0)
# parc_centers = parcel_centers(parcellation_file)
x = parc_centers[:, 0]
y = parc_centers[:, 1]
z = parc_centers[:, 2]
# Re-order to use colour to indicate anterior->posterior location
# ref = np.argsort(y)[-1]
# dist = np.sqrt((x[ref]-x)**2 + (y[ref]-y)**2 + (z[ref]-z)**2)
ref = [-5, -5, -3]
# dist = np.sqrt((ref[0]-x)**2 + (ref[1]-y)**2 + (ref[2]-z)**2)
# order = np.argsort(dist)
colors = mne.viz.evoked._rgb(x, y, z)
order = [np.argsort(np.sqrt(ref[i]-parc_centers[:,i])**2) for i in range(3)]
colors = np.vstack([colors[order[0],0], colors[order[1],1], colors[order[2],2]]).T
if return_order:
return colors, order
else:
return colors


def get_mne_sensor_cols(info):
""" Get sensor colours from MNE info object.
Expand Down Expand Up @@ -966,7 +1092,7 @@ def plot_channel_layout(ax, info, size=30, marker='o'):
info : :py:class:`mne.Info <mne.Info>`
MNE-Python info object
size : int
Size of sensor markers (Default value = 30)
Size of sensor § (Default value = 30)
marker : str
Marker type (Default value = 'o')
Expand Down
40 changes: 40 additions & 0 deletions osl/source_recon/parcellation/parcellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,46 @@ def find_file(filename):
return filename


def guess_parcellation(data, return_path=False):
"""Guess parcellation file from data.
Parameters
----------
data : vector or matrix
Data to guess parcellation from. first dimension is assumed to be parcels.
return_path : bool
If True, return path to parcellation file, otherwise return filename.
returns
-------
filename : str
Path to parcellation file.
"""
if type(data) is int:
nparc = data
else:
nparc = data.shape[0]

# print('Guessing parcellation from data with {} parcels'.format(nparc))
if nparc==52:
fname = "Glasser52_binary_space-MNI152NLin6_res-8x8x8.nii.gz"
elif nparc==50:
fname = "Glasser50_space-MNI152NLin6_res-8x8x8.nii.gz"
elif nparc==38:
fname = "fMRI_parcellation_ds8mm.nii.gz"
elif nparc==39:
fname = "fmri_d100_parcellation_with_PCC_tighterMay15_v2_8mm.nii.gz"
elif nparc==78:
fname = "aal_cortical_merged_8mm_stacked.nii.gz"
else:
raise ValueError("Can't guess parcellation for {} channels".format(nparc))
# print('Guessing parcellation is {}'.format(fname))
if return_path:
return find_file(fname)
else:
return find_file(fname).split('/')[-1]


def parcellate_timeseries(parcellation_file, voxel_timeseries, voxel_coords, method, working_dir):
"""Parcellate a voxel time series.
Expand Down

0 comments on commit a33f77b

Please sign in to comment.