Skip to content

Commit

Permalink
Merge pull request #349 from OHBA-analysis/source_cols
Browse files Browse the repository at this point in the history
update source colors
  • Loading branch information
matsvanes authored Sep 24, 2024
2 parents 8d286a0 + c3e8e7f commit f7c52cc
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 45 deletions.
66 changes: 43 additions & 23 deletions osl/glm/glm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import glmtools as glm
import nibabel as nib
from nilearn.plotting import plot_glass_brain, plot_markers

from ..source_recon import parcellation

Expand Down Expand Up @@ -728,28 +727,43 @@ def plot_sensor_erp(xvect, erp, info, ax=None, sensor_proj=False,

if sensor_proj:
axins = ax.inset_axes([0.6, 0.6, 0.37, 0.37])
if np.any(['parcel' in ch for ch in info['ch_names']]):
parcellation_file = parcellation.guess_parcellation(erp.T)
colors = get_source_colors(parcellation_file)
cmap = ListedColormap(colors)
parc_centers = parcellation.parcel_centers(parcellation_file)
n_parcels = parc_centers.shape[0]
plot_markers(
np.arange(n_parcels),
parc_centers,
axes=axins,
node_size=6,
node_cmap=cmap,
annotate=False,
colorbar=False,
)
else:
plot_channel_layout(axins, info)
plot_sensor_proj(erp, info, ax=axins)

if title is not None:
ax.set_title(title)


def plot_sensor_proj(info, ax=None, cmap=None):
if ax is None:
fig = plt.figure()
ax = plt.subplot(111)
if np.any(['parcel' in ch for ch in info['ch_names']]):
parcellation_file = parcellation.guess_parcellation(len(info.ch_names))
parc_centers = parcellation.parcel_centers(parcellation_file)
if cmap is None:
cmap = 'viridis'
x, y, z = parc_centers.T
X = y
else:
colors = get_source_colors(parcellation_file)
cmap = ListedColormap(colors)
X = np.arange(n_parcels)

n_parcels = parc_centers.shape[0]
plot_markers(
X,
parc_centers,
axes=ax,
node_size=20,
node_cmap=cmap,
annotate=False,
colorbar=False,
)
else:
plot_channel_layout(ax, info)
return ax


def plot_sensor_data(xvect, data, info, ax=None, lw=0.5,
xticks=None, xticklabels=None,
sensor_cols=True, xtick_skip=1):
Expand Down Expand Up @@ -826,13 +840,19 @@ def decorate_spectrum(ax, ylabel='Amplitude'):
ax.set_ylabel(ylabel)


def get_source_colors(parcellation_file):
def get_source_colors(parcellation_file, cmap='viridis'):
parc_centers = stats.zscore(parcellation.parcel_centers(parcellation_file), axis=0)
x, y, z = parc_centers.T
ref = [-5, -5, -3]
colors = mne.viz.evoked._rgb(x, y, z)
order = [np.argsort(np.sqrt(ref[i] - parc_centers[:, i]) ** 2) for i in range(3)]
colors = np.vstack([colors[order[0],0], colors[order[1],1], colors[order[2],2]]).T
if cmap=='viridis':
cmap = plt.get_cmap('viridis')
norm = plt.Normalize(vmin=parc_centers.min(), vmax=parc_centers.max())
colors = cmap(norm(parc_centers))[:,1,:]
# colors = colors[np.argsort(y), :]
else:
ref = [-5, -5, -3]
colors = mne.viz.evoked._rgb(x, y, z)
order = [np.argsort(np.sqrt(ref[i] - parc_centers[:, i]) ** 2) for i in range(3)]
colors = np.vstack([colors[order[0],0], colors[order[1],1], colors[order[2],2]]).T
return colors


Expand Down
65 changes: 43 additions & 22 deletions osl/glm/glm_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,28 +1151,43 @@ def plot_sensor_spectrum(xvect, psd, info, ax=None, sensor_proj=False,

if sensor_proj:
axins = ax.inset_axes([0.6, 0.6, 0.37, 0.37])
if np.any(['parcel' in ch for ch in info['ch_names']]):
parcellation_file = parcellation.guess_parcellation(psd.T)
colors = get_source_colors(parcellation_file)
cmap = ListedColormap(colors)
parc_centers = parcellation.parcel_centers(parcellation_file)
n_parcels = parc_centers.shape[0]
plot_markers(
np.arange(n_parcels),
parc_centers,
axes=axins,
node_size=6,
node_cmap=cmap,
annotate=False,
colorbar=False,
)
else:
plot_channel_layout(axins, info)
plot_sensor_proj(info, ax=axins)

if title is not None:
ax.set_title(title)


def plot_sensor_proj(info, ax=None, cmap=None):
if ax is None:
fig = plt.figure()
ax = plt.subplot(111)
if np.any(['parcel' in ch for ch in info['ch_names']]):
parcellation_file = parcellation.guess_parcellation(len(info.ch_names))
parc_centers = parcellation.parcel_centers(parcellation_file)
if cmap is None:
cmap = 'viridis'
x, y, z = parc_centers.T
X = y
else:
colors = get_source_colors(parcellation_file)
cmap = ListedColormap(colors)
X = np.arange(n_parcels)

n_parcels = parc_centers.shape[0]
plot_markers(
X,
parc_centers,
axes=ax,
node_size=20,
node_cmap=cmap,
annotate=False,
colorbar=False,
)
else:
plot_channel_layout(ax, info)
return ax


def plot_sensor_data(xvect, data, info, ax=None, lw=0.5,
xticks=None, xticklabels=None,
sensor_cols=True, base=1, xtick_skip=1):
Expand Down Expand Up @@ -1241,13 +1256,19 @@ def prep_scaled_freq(base, freq_vect):
return fx, ftick, ftickscaled


def get_source_colors(parcellation_file):
def get_source_colors(parcellation_file, cmap='viridis'):
parc_centers = stats.zscore(parcellation.parcel_centers(parcellation_file), axis=0)
x, y, z = parc_centers.T
ref = [-5, -5, -3]
colors = mne.viz.evoked._rgb(x, y, z)
order = [np.argsort(np.sqrt(ref[i] - parc_centers[:, i]) ** 2) for i in range(3)]
colors = np.vstack([colors[order[0],0], colors[order[1],1], colors[order[2],2]]).T
if cmap=='viridis':
cmap = plt.get_cmap('viridis')
norm = plt.Normalize(vmin=parc_centers.min(), vmax=parc_centers.max())
colors = cmap(norm(parc_centers))[:,1,:]
# colors = colors[np.argsort(y), :]
else:
ref = [-5, -5, -3]
colors = mne.viz.evoked._rgb(x, y, z)
order = [np.argsort(np.sqrt(ref[i] - parc_centers[:, i]) ** 2) for i in range(3)]
colors = np.vstack([colors[order[0],0], colors[order[1],1], colors[order[2],2]]).T
return colors


Expand Down

0 comments on commit f7c52cc

Please sign in to comment.