Skip to content

Commit

Permalink
Merge pull request #339 from OHBA-analysis/ica_label
Browse files Browse the repository at this point in the history
always keep ecg/eog on top in ica labeling
  • Loading branch information
matsvanes authored Sep 4, 2024
2 parents 8daf84c + 213e21c commit 6446af3
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 31 deletions.
5 changes: 5 additions & 0 deletions osl/preprocessing/ica_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ def ica_label(data_dir, subject, reject=None, interactive=True):
logger.info("Not removing any components from the data")

logger.info("Saving ICA data")

# make sure the format is correct, otherwise errors will occur
for key in ica.labels_.keys():
ica.labels_[key] = list(ica.labels_[key])

ica.save(ica_file, overwrite=True)

if reject is not None:
Expand Down
109 changes: 78 additions & 31 deletions osl/preprocessing/plot_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,27 @@ def _plot_sources(
# add EOG/ECG channels if present
eog_chs = pick_types(inst.info, meg=False, eog=True, ref_meg=False)
ecg_chs = pick_types(inst.info, meg=False, ecg=True, ref_meg=False)
for eog_idx in eog_chs:
ch_names.append(inst.ch_names[eog_idx])
ch_types.append("eog")
for ecg_idx in ecg_chs:
ch_names.append(inst.ch_names[ecg_idx])
ch_types.append("ecg")
extra_picks = np.concatenate((eog_chs, ecg_chs)).astype(int)
extra_picks = pick_types(inst.info, meg=False, ecg=True, eog=True, ref_meg=False)
for idx in extra_picks[::-1]:
ch_names.insert(0, inst.ch_names[idx])
ch_types.insert(0, "eog" if idx in eog_chs else "ecg")
# for eog_idx in eog_chs[::-1]:
# ch_names.insert(0, inst.ch_names[eog_idx])
# ch_types.insert(0, "eog")
# # ch_names.append(inst.ch_names[eog_idx])
# # ch_types.append("eog")
# for ecg_idx in ecg_chs[::-1]:
# ch_names.insert(0, inst.ch_names[ecg_idx])
# ch_types.insert(0, "ecg")
# # ch_names.append(inst.ch_names[ecg_idx])
# # ch_types.append("ecg")
# extra_picks = np.concatenate((eog_chs, ecg_chs)).astype(int)
if len(extra_picks):
if is_raw:
eog_ecg_data, _ = inst[extra_picks, :]
else:
eog_ecg_data = np.concatenate(inst.get_data(extra_picks), axis=1)
data = np.append(data, eog_ecg_data, axis=0)
data = np.append(eog_ecg_data, data, axis=0)
picks = np.concatenate((picks, ica.n_components_ + np.arange(len(extra_picks))))
ch_order = np.arange(len(picks))
n_channels = min([n_channels, len(picks)])
Expand Down Expand Up @@ -312,7 +320,7 @@ def _plot_sources(
)

fig = _get_browser(**params)

fig.mne.ch_start = len(extra_picks) # this is necessary to make sure to plot the EOG/ECG only once
fig._update_picks()

# update data, and plot
Expand Down Expand Up @@ -398,6 +406,7 @@ def __init__(self, inst, figsize, ica=None,
from mpl_toolkits.axes_grid1.axes_size import Fixed

# # OSL IMPORTS
from mne import pick_types
from mne import BaseEpochs
from mne.io import BaseRaw
from mne.preprocessing import ICA
Expand Down Expand Up @@ -425,6 +434,7 @@ def __init__(self, inst, figsize, ica=None,
vscroll_dist = 0.1
help_width = scroll_width * 2
# MVE: ADD SIZES FOR TOPOS
extra_chans = pick_types(inst.info, meg=False, eeg=False, ref_meg=False, eog=True, ecg=True, exclude=[])
exist_meg = any(ct in np.unique(ica.get_channel_types()) for ct in ['mag', 'grad'])
exist_eeg = 'eeg' in np.unique(ica.get_channel_types())
n_topos = len(
Expand All @@ -435,6 +445,7 @@ def __init__(self, inst, figsize, ica=None,
]
)
)
# n_topos -= len(extra_chans)
topo_width_ratio = 8 + n_topos # 1
topo_dist = self._inch_to_rel(0.05) # 0.25

Expand Down Expand Up @@ -675,6 +686,29 @@ def __init__(self, inst, figsize, ica=None,
vline_text=vline_text,
)


def _update_picks(self):
import numpy as np
"""Compute which channel indices to show."""
n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg']))
if self.mne.butterfly and self.mne.ch_selections is not None:
selections_dict = self._make_butterfly_selections_dict()
self.mne.picks = np.concatenate(tuple(selections_dict.values()))
elif self.mne.butterfly:
self.mne.picks = self.mne.ch_order
else:
# this is replaced:
# _slice = slice(self.mne.picks[n_extra_chans],
# self.mne.picks[n_extra_chans] + self.mne.n_channels)
# self.mne.picks = self.mne.ch_order[_slice]

_slice = slice(self.mne.ch_start,
self.mne.ch_start + self.mne.n_channels - n_extra_chans )
self.mne.picks = np.concatenate([np.arange(n_extra_chans), self.mne.ch_order[_slice]])
self.mne.n_channels = len(self.mne.picks)
assert isinstance(self.mne.picks, np.ndarray)
assert self.mne.picks.dtype.kind == 'i'

def _draw_traces(self):
"""Draw (or redraw) the channel data."""

Expand All @@ -700,8 +734,10 @@ def _draw_traces(self):
bad_bool = np.in1d(ch_names, self.mne.info["bads"])
# OSL ADDITION
bad_int = list(np.ones(len(picks))*-1)
tmppicks = [picks[k] for k in np.where([j<len(self.mne.ica._ica_names) for j in picks])[0]] # we don't want to do this for the artefact channels (e.g. EOC/ECG)
for cnt, ch in enumerate([self.mne.ica._ica_names[ii] for ii in tmppicks]):
extra_chans = [picks[k] for k, ch_type in enumerate(ch_types) if ch_type == 'eog' or ch_type=='ecg']
for cnt, ch in enumerate([self.mne.ch_names[ii] for ii in picks]):
if cnt < len(extra_chans):
continue
i = self.mne.ica._ica_names.index(ch)
if ch in self.mne.info["bads"]:
if len(list(self.mne.ica.labels_.values())) > 0 and i in np.concatenate(list(self.mne.ica.labels_.values())):
Expand Down Expand Up @@ -898,6 +934,7 @@ def plot_topos(self, ica, ax_topo, picks): # OSL ADDITION FOR TOPOS
import mne
from mne.viz.topomap import _plot_ica_topomap

extra_chans = [k for k, ch_type in enumerate(self.mne.ch_types[picks]) if ch_type == 'eog' or ch_type == 'ecg']
exist_meg = any(ct in np.unique(ica.get_channel_types()) for ct in ['mag', 'grad'])
exist_eeg = 'eeg' in np.unique(ica.get_channel_types())
n_topos = len(picks)
Expand All @@ -913,10 +950,14 @@ def plot_topos(self, ica, ax_topo, picks): # OSL ADDITION FOR TOPOS
n_chtype = len(chtype)
for i in range(n_chtype):
for j in range(n_topos):
if picks[j] < ncomps:
if picks[j]<len(extra_chans):
ax_topo[i, j].clear()
ax_topo[i, j].set_axis_off()
else:

_plot_ica_topomap(
ica_tmp,
idx=picks[j],
idx=picks[j]-len(extra_chans),
ch_type=chtype[i],
axes=ax_topo[i, j],
vmin=None,
Expand All @@ -933,14 +974,8 @@ def plot_topos(self, ica, ax_topo, picks): # OSL ADDITION FOR TOPOS
allow_ref_meg=False,
sphere=None,
)
else:
# We likely have an EOG/ECG comp - don't plot a topo
ax_topo[i, j].clear()
ax_topo[i, j].set_xticks([])
ax_topo[i, j].set_yticks([])

if j==0:
ax_topo[i, j].set_title(f"{chtype[i]}")
ax_topo[i, j].set_title(f"{chtype[i]}")
else:
ax_topo[i, j].set_title('')

Expand All @@ -951,6 +986,7 @@ def _keypress(self, event):
"""Handle keypress events."""
key = event.key
n_channels = self.mne.n_channels
n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg']))
if self.mne.is_epochs:
last_time = self.mne.n_times / self.mne.info["sfreq"]
else:
Expand Down Expand Up @@ -982,10 +1018,11 @@ def _keypress(self, event):
buttons.set_active(current_idx + direction)
# normal case
else:
ceiling = len(self.mne.ch_order) - n_channels
ch_start = self.mne.ch_start + direction * n_channels
self.mne.ch_start = np.clip(ch_start, 0, ceiling)
self._update_picks()
ceiling = len(self.mne.ch_order)
ch_start = self.mne.picks[n_extra_chans] + direction * (n_channels - n_extra_chans)
# ch_start = np.clip(self.mne.ch_start, n_extra_chans, ceiling) + direction * (n_channels - n_extra_chans)
self.mne.ch_start = np.clip(ch_start, n_extra_chans, ceiling)
self._update_picks()
self._update_vscroll()
self._redraw()
# scroll left/right
Expand Down Expand Up @@ -1144,38 +1181,46 @@ def _close(self, event):
# OSL ADDITION: remove bad component labels that were reversed to good component
tmp = list(self.mne.ica.labels_.values())[:]
try:
tmp = np.concatenate(tmp)
tmp = np.unique(np.concatenate(tmp))
except:
tmp = []

for ch in tmp:
ch = int(ch)
if ch not in self.mne.ica.exclude:
# find in which label it has
allix = np.where(list(self.mne.ica.labels_.values()) == ch)
allix = np.where([ch in self.mne.ica.labels_[key] for key in self.mne.ica.labels_.keys()])[0]
for ix in allix:
self.mne.ica.labels_[list(self.mne.ica.labels_.keys())[ix]] = \
np.setdiff1d(self.mne.ica.labels_[list(self.mne.ica.labels_.keys())[ix]], ch)

# label bad components without a manual label as "unknown"
for ch in self.mne.ica.exclude:
ch = int(ch)
tmp = list(self.mne.ica.labels_.values())
if len(tmp)==0:
tmp = []
else:
tmp = np.concatenate(tmp)
if ch not in tmp:
if "unknown" not in self.mne.ica.labels_:
if "unknown" not in self.mne.ica.labels_.keys():
self.mne.ica.labels_["unknown"] = []
self.mne.ica.labels_["unknown"] = list(self.mne.ica.labels_["unknown"])
self.mne.ica.labels_["unknown"].append(ch)
if type(self.mne.ica.labels_["unknown"]) is np.ndarray:
self.mne.ica.labels_["unknown"] = self.mne.ica.labels_["unknown"].tolist()



# Add to labels_ a generic eog/ecg field
if len(list(self.mne.ica.labels_.keys())) > 0:
if "ecg" not in self.mne.ica.labels_:
self.mne.ica.labels_["ecg"] = []
if "eog" not in self.mne.ica.labels_:
self.mne.ica.labels_["eog"] = []
for key in self.mne.ica.labels_.keys():
self.mne.ica.labels_[key] = list(self.mne.ica.labels_[key])

for key in self.mne.ica.labels_.keys():
self.mne.ica.labels_[key] = list(self.mne.ica.labels_[key])

for k in list(self.mne.ica.labels_.keys()):
if "ecg" in k.lower() and k.lower() != "ecg":
tmp = self.mne.ica.labels_[k]
Expand All @@ -1191,7 +1236,9 @@ def _close(self, event):
self.mne.ica.labels_["eog"] = [v for v in self.mne.ica.labels_["eog"] if v!= []]
self.mne.ica.labels_["ecg"] = np.unique(self.mne.ica.labels_["ecg"]).tolist()
self.mne.ica.labels_["eog"] = np.unique(self.mne.ica.labels_["eog"]).tolist()

for key in self.mne.ica.labels_.keys():
self.mne.ica.labels_[key] = list(self.mne.ica.labels_[key])

# write logs
logger.info(f"Components marked as bad: {sorted(self.mne.ica.exclude) or 'none'}")
for lb in self.mne.ica.labels_.keys():
Expand Down

0 comments on commit 6446af3

Please sign in to comment.