Skip to content

Commit

Permalink
Merge pull request #341 from OHBA-analysis/ica_fix
Browse files Browse the repository at this point in the history
fix vscroll and adapt topo layout
  • Loading branch information
matsvanes authored Sep 6, 2024
2 parents ae26c31 + c6035f4 commit cd3b54e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 55 deletions.
47 changes: 24 additions & 23 deletions osl/preprocessing/ica_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from osl.preprocessing.plot_ica import plot_ica
from osl.report import plot_bad_ica
from osl.report.preproc_report import gen_html_page, gen_html_summary
from ..utils import logger as osl_logger
from osl.utils import logger as osl_logger

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -195,31 +195,32 @@ def main(argv=None):
"""

if argv is None:
argv = sys.argv[1:]
# if argv is None:
# argv = sys.argv[1:]

reject = argv[0]
if reject == 'None':
reject = None
# reject = argv[0]
# if reject == 'None':
# reject = None

if len(argv)<3:
data_dir = os.getcwd()
if len(argv)==2:
subject = argv[1]
else:
g = sorted(glob(os.path.join(f"{data_dir}", '*', '*_ica.fif')))
subject = [f.split('/')[-2] for f in g]
# batch log
logs_dir = os.path.join(data_dir, 'logs')
logfile = os.path.join(logs_dir, 'osl_batch.log')
osl_logger.set_up(log_file=logfile, level="INFO", startup=False)
logger.info('Starting OSL-ICA Batch Processing')
logger.info('Running osl_ica_label on {0} subjects with reject={1}'.format(len(subject), str(reject)))
else:
data_dir = argv[1]
subject = argv[2]
# if len(argv)<3:
# data_dir = os.getcwd()
# if len(argv)==2:
# subject = argv[1]
# else:
# g = sorted(glob(os.path.join(f"{data_dir}", '*', '*_ica.fif')))
# subject = [f.split('/')[-2] for f in g]
# # batch log
# logs_dir = os.path.join(data_dir, 'logs')
# logfile = os.path.join(logs_dir, 'osl_batch.log')
# osl_logger.set_up(log_file=logfile, level="INFO", startup=False)
# logger.info('Starting OSL-ICA Batch Processing')
# logger.info('Running osl_ica_label on {0} subjects with reject={1}'.format(len(subject), str(reject)))
# else:
# data_dir = argv[1]
# subject = argv[2]

ica_label(data_dir=data_dir, subject=subject, reject=reject)
# ica_label(data_dir=data_dir, subject=subject, reject=reject)
ica_label(data_dir='/ohba/pi/mwoolrich/osl-dev/meguk-debugging/output_mve/', subject='sub-oxf001_task-resteyesclosed', reject=None)


def apply(argv=None):
Expand Down
48 changes: 16 additions & 32 deletions osl/preprocessing/plot_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# Authors: Mats van Es <mats.vanes@psych.ox.ac.uk>
import logging
import numpy as np
import matplotlib.pyplot as plt

# Configure logging
Expand Down Expand Up @@ -151,7 +152,6 @@ def _plot_sources(
from mne.io.meas_info import create_info
from mne.io.pick import pick_types
from mne.defaults import _handle_default
import numpy as np

# handle defaults / check arg validity
is_raw = isinstance(inst, BaseRaw)
Expand Down Expand Up @@ -186,22 +186,10 @@ def _plot_sources(

# add EOG/ECG channels if present
eog_chs = pick_types(inst.info, meg=False, eog=True, ref_meg=False)
ecg_chs = pick_types(inst.info, meg=False, ecg=True, ref_meg=False)
extra_picks = pick_types(inst.info, meg=False, ecg=True, eog=True, ref_meg=False)
for idx in extra_picks[::-1]:
ch_names.insert(0, inst.ch_names[idx])
ch_types.insert(0, "eog" if idx in eog_chs else "ecg")
# for eog_idx in eog_chs[::-1]:
# ch_names.insert(0, inst.ch_names[eog_idx])
# ch_types.insert(0, "eog")
# # ch_names.append(inst.ch_names[eog_idx])
# # ch_types.append("eog")
# for ecg_idx in ecg_chs[::-1]:
# ch_names.insert(0, inst.ch_names[ecg_idx])
# ch_types.insert(0, "ecg")
# # ch_names.append(inst.ch_names[ecg_idx])
# # ch_types.append("ecg")
# extra_picks = np.concatenate((eog_chs, ecg_chs)).astype(int)
if len(extra_picks):
if is_raw:
eog_ecg_data, _ = inst[extra_picks, :]
Expand Down Expand Up @@ -350,7 +338,6 @@ def _get_browser(**kwargs):
"""
from mne.viz.utils import _get_figsize_from_config
from mne.viz._figure import _init_browser_backend
import numpy as np

figsize = kwargs.setdefault("figsize", _get_figsize_from_config())
if figsize is None or np.any(np.array(figsize) < 8):
Expand Down Expand Up @@ -379,7 +366,6 @@ def _init_browser(backend, **kwargs): # OSL ADDITION IN ORDER TO USE OSL'S FIGU
fig.canvas.draw()
fig._update_zen_mode_offsets()
fig._resize(None) # needed for MPL >=3.4

# if scrollbars are supposed to start hidden,
# set to True and then toggle
if not fig.mne.scrollbars_visible:
Expand Down Expand Up @@ -412,7 +398,6 @@ def __init__(self, inst, figsize, ica=None,
from mne.preprocessing import ICA
from mne.viz._figure import BrowserBase
from mne.viz._mpl_figure import MNEFigure, _patched_canvas
import numpy as np
import mne
from functools import partial

Expand Down Expand Up @@ -445,7 +430,6 @@ def __init__(self, inst, figsize, ica=None,
]
)
)
# n_topos -= len(extra_chans)
topo_width_ratio = 8 + n_topos # 1
topo_dist = self._inch_to_rel(0.05) # 0.25

Expand Down Expand Up @@ -486,9 +470,8 @@ def __init__(self, inst, figsize, ica=None,
topo_position = [
left + i * (topo_width + topo_dist),
bottom
+ self._inch_to_rel(hscroll_dist + b_margin)
+ ((self.mne.n_channels - 1) - j) * (topo_height + topo_dist)
+ topo_dist,
+ ((self.mne.n_channels) - j) * (topo_height + topo_dist)*1.03
- self._inch_to_rel(0.13),
topo_width,
topo_height,
]
Expand Down Expand Up @@ -566,7 +549,7 @@ def __init__(self, inst, figsize, ica=None,

# VERTICAL SCROLLBAR PATCHES (COLORED BY CHANNEL TYPE)
ch_order = self.mne.ch_order
for ix, pick in enumerate(ch_order):
for ix, pick in enumerate(ch_order[len(extra_chans):]):
this_color = (
self.mne.ch_color_bad
if self.mne.ch_names[pick] in self.mne.info["bads"]
Expand All @@ -579,14 +562,14 @@ def __init__(self, inst, figsize, ica=None,
(0, ix), 1, 1, color=this_color, zorder=self.mne.zorder["patch"]
)
)
ax_vscroll.set_ylim(len(ch_order), 0)
ax_vscroll.set_ylim(len(ch_order) - len(extra_chans), 0)
ax_vscroll.set_visible(not self.mne.butterfly)
# SCROLLBAR VISIBLE SELECTION PATCHES
sel_kwargs = dict(
alpha=0.3, linewidth=4, clip_on=False, edgecolor=self.mne.fgcolor
)
vsel_patch = Rectangle(
(0, 0), 1, self.mne.n_channels, facecolor=self.mne.bgcolor, **sel_kwargs
(0, 0), 1, self.mne.n_channels - len(extra_chans), facecolor=self.mne.bgcolor, **sel_kwargs
)
ax_vscroll.add_patch(vsel_patch)
hsel_facecolor = np.average(
Expand Down Expand Up @@ -688,7 +671,6 @@ def __init__(self, inst, figsize, ica=None,


def _update_picks(self):
import numpy as np
"""Compute which channel indices to show."""
n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg']))
if self.mne.butterfly and self.mne.ch_selections is not None:
Expand Down Expand Up @@ -718,7 +700,6 @@ def _draw_traces(self):
# OSL ADDITION
from mne import pick_types
from mne.io.pick import channel_type
import numpy as np

# clear scalebars
if self.mne.scalebars_visible:
Expand Down Expand Up @@ -928,9 +909,9 @@ def _draw_traces(self):
plt.figtext(self.mne.bad_labels_xpos, self.mne.bad_labels_ypos[i], f'{i-1}: ' + self.mne.bad_labels_list[i - 2],
color=self.mne.bad_label_colors[i - 2], fontweight='semibold')

self._update_vscroll() # takes care of the vsel_patch, because it's too big when there's extra chans

def plot_topos(self, ica, ax_topo, picks): # OSL ADDITION FOR TOPOS
import numpy as np
import mne
from mne.viz.topomap import _plot_ica_topomap

Expand Down Expand Up @@ -981,7 +962,6 @@ def plot_topos(self, ica, ax_topo, picks): # OSL ADDITION FOR TOPOS

def _keypress(self, event):
from mne.viz.utils import _events_off
import numpy as np

"""Handle keypress events."""
key = event.key
Expand Down Expand Up @@ -1018,9 +998,8 @@ def _keypress(self, event):
buttons.set_active(current_idx + direction)
# normal case
else:
ceiling = len(self.mne.ch_order)
ch_start = self.mne.picks[n_extra_chans] + direction * (n_channels - n_extra_chans)
# ch_start = np.clip(self.mne.ch_start, n_extra_chans, ceiling) + direction * (n_channels - n_extra_chans)
ceiling = len(self.mne.ch_order) - (n_channels - n_extra_chans)
ch_start = self.mne.ch_start + direction * (n_channels - n_extra_chans)
self.mne.ch_start = np.clip(ch_start, n_extra_chans, ceiling)
self._update_picks()
self._update_vscroll()
Expand Down Expand Up @@ -1149,13 +1128,18 @@ def _keypress(self, event):
else: # check for close key / fullscreen toggle
super()._keypress(event)


def _update_vscroll(self):
"""Update the vertical scrollbar (channel) selection indicator."""
n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg']))
self.mne.vsel_patch.set_xy((0, self.mne.ch_start - n_extra_chans))
self.mne.vsel_patch.set_height(self.mne.n_channels - n_extra_chans)
self._update_yaxis_labels()

def _close(self, event):
# OSL VERSION - SIMILAR TO OLD MNE VERSION TODO: Check if we need to adopt this
"""Handle close events (via keypress or window [x])."""
from matplotlib.pyplot import close
from mne.utils import set_config
import numpy as np

# write out bad epochs (after converting epoch numbers to indices)
if self.mne.instance_type == "epochs":
Expand Down

0 comments on commit cd3b54e

Please sign in to comment.