From c6035f4709695776c9a6443c28e34f8f19a9ddf2 Mon Sep 17 00:00:00 2001 From: matsvanes Date: Fri, 6 Sep 2024 12:42:32 +0100 Subject: [PATCH] fix vscroll and adapt topo layout --- osl/preprocessing/ica_label.py | 47 +++++++++++++++++---------------- osl/preprocessing/plot_ica.py | 48 ++++++++++++---------------------- 2 files changed, 40 insertions(+), 55 deletions(-) diff --git a/osl/preprocessing/ica_label.py b/osl/preprocessing/ica_label.py index 5778781..132cab7 100644 --- a/osl/preprocessing/ica_label.py +++ b/osl/preprocessing/ica_label.py @@ -21,7 +21,7 @@ from osl.preprocessing.plot_ica import plot_ica from osl.report import plot_bad_ica from osl.report.preproc_report import gen_html_page, gen_html_summary -from ..utils import logger as osl_logger +from osl.utils import logger as osl_logger logger = logging.getLogger(__name__) @@ -195,31 +195,32 @@ def main(argv=None): """ - if argv is None: - argv = sys.argv[1:] + # if argv is None: + # argv = sys.argv[1:] - reject = argv[0] - if reject == 'None': - reject = None + # reject = argv[0] + # if reject == 'None': + # reject = None - if len(argv)<3: - data_dir = os.getcwd() - if len(argv)==2: - subject = argv[1] - else: - g = sorted(glob(os.path.join(f"{data_dir}", '*', '*_ica.fif'))) - subject = [f.split('/')[-2] for f in g] - # batch log - logs_dir = os.path.join(data_dir, 'logs') - logfile = os.path.join(logs_dir, 'osl_batch.log') - osl_logger.set_up(log_file=logfile, level="INFO", startup=False) - logger.info('Starting OSL-ICA Batch Processing') - logger.info('Running osl_ica_label on {0} subjects with reject={1}'.format(len(subject), str(reject))) - else: - data_dir = argv[1] - subject = argv[2] + # if len(argv)<3: + # data_dir = os.getcwd() + # if len(argv)==2: + # subject = argv[1] + # else: + # g = sorted(glob(os.path.join(f"{data_dir}", '*', '*_ica.fif'))) + # subject = [f.split('/')[-2] for f in g] + # # batch log + # logs_dir = os.path.join(data_dir, 'logs') + # logfile = os.path.join(logs_dir, 'osl_batch.log') + # osl_logger.set_up(log_file=logfile, level="INFO", startup=False) + # logger.info('Starting OSL-ICA Batch Processing') + # logger.info('Running osl_ica_label on {0} subjects with reject={1}'.format(len(subject), str(reject))) + # else: + # data_dir = argv[1] + # subject = argv[2] - ica_label(data_dir=data_dir, subject=subject, reject=reject) + # ica_label(data_dir=data_dir, subject=subject, reject=reject) + ica_label(data_dir='/ohba/pi/mwoolrich/osl-dev/meguk-debugging/output_mve/', subject='sub-oxf001_task-resteyesclosed', reject=None) def apply(argv=None): diff --git a/osl/preprocessing/plot_ica.py b/osl/preprocessing/plot_ica.py index 0b4018e..9f18076 100644 --- a/osl/preprocessing/plot_ica.py +++ b/osl/preprocessing/plot_ica.py @@ -4,6 +4,7 @@ # Authors: Mats van Es import logging +import numpy as np import matplotlib.pyplot as plt # Configure logging @@ -151,7 +152,6 @@ def _plot_sources( from mne.io.meas_info import create_info from mne.io.pick import pick_types from mne.defaults import _handle_default - import numpy as np # handle defaults / check arg validity is_raw = isinstance(inst, BaseRaw) @@ -186,22 +186,10 @@ def _plot_sources( # add EOG/ECG channels if present eog_chs = pick_types(inst.info, meg=False, eog=True, ref_meg=False) - ecg_chs = pick_types(inst.info, meg=False, ecg=True, ref_meg=False) extra_picks = pick_types(inst.info, meg=False, ecg=True, eog=True, ref_meg=False) for idx in extra_picks[::-1]: ch_names.insert(0, inst.ch_names[idx]) ch_types.insert(0, "eog" if idx in eog_chs else "ecg") - # for eog_idx in eog_chs[::-1]: - # ch_names.insert(0, inst.ch_names[eog_idx]) - # ch_types.insert(0, "eog") - # # ch_names.append(inst.ch_names[eog_idx]) - # # ch_types.append("eog") - # for ecg_idx in ecg_chs[::-1]: - # ch_names.insert(0, inst.ch_names[ecg_idx]) - # ch_types.insert(0, "ecg") - # # ch_names.append(inst.ch_names[ecg_idx]) - # # ch_types.append("ecg") - # extra_picks = np.concatenate((eog_chs, ecg_chs)).astype(int) if len(extra_picks): if is_raw: eog_ecg_data, _ = inst[extra_picks, :] @@ -350,7 +338,6 @@ def _get_browser(**kwargs): """ from mne.viz.utils import _get_figsize_from_config from mne.viz._figure import _init_browser_backend - import numpy as np figsize = kwargs.setdefault("figsize", _get_figsize_from_config()) if figsize is None or np.any(np.array(figsize) < 8): @@ -379,7 +366,6 @@ def _init_browser(backend, **kwargs): # OSL ADDITION IN ORDER TO USE OSL'S FIGU fig.canvas.draw() fig._update_zen_mode_offsets() fig._resize(None) # needed for MPL >=3.4 - # if scrollbars are supposed to start hidden, # set to True and then toggle if not fig.mne.scrollbars_visible: @@ -412,7 +398,6 @@ def __init__(self, inst, figsize, ica=None, from mne.preprocessing import ICA from mne.viz._figure import BrowserBase from mne.viz._mpl_figure import MNEFigure, _patched_canvas - import numpy as np import mne from functools import partial @@ -445,7 +430,6 @@ def __init__(self, inst, figsize, ica=None, ] ) ) - # n_topos -= len(extra_chans) topo_width_ratio = 8 + n_topos # 1 topo_dist = self._inch_to_rel(0.05) # 0.25 @@ -486,9 +470,8 @@ def __init__(self, inst, figsize, ica=None, topo_position = [ left + i * (topo_width + topo_dist), bottom - + self._inch_to_rel(hscroll_dist + b_margin) - + ((self.mne.n_channels - 1) - j) * (topo_height + topo_dist) - + topo_dist, + + ((self.mne.n_channels) - j) * (topo_height + topo_dist)*1.03 + - self._inch_to_rel(0.13), topo_width, topo_height, ] @@ -566,7 +549,7 @@ def __init__(self, inst, figsize, ica=None, # VERTICAL SCROLLBAR PATCHES (COLORED BY CHANNEL TYPE) ch_order = self.mne.ch_order - for ix, pick in enumerate(ch_order): + for ix, pick in enumerate(ch_order[len(extra_chans):]): this_color = ( self.mne.ch_color_bad if self.mne.ch_names[pick] in self.mne.info["bads"] @@ -579,14 +562,14 @@ def __init__(self, inst, figsize, ica=None, (0, ix), 1, 1, color=this_color, zorder=self.mne.zorder["patch"] ) ) - ax_vscroll.set_ylim(len(ch_order), 0) + ax_vscroll.set_ylim(len(ch_order) - len(extra_chans), 0) ax_vscroll.set_visible(not self.mne.butterfly) # SCROLLBAR VISIBLE SELECTION PATCHES sel_kwargs = dict( alpha=0.3, linewidth=4, clip_on=False, edgecolor=self.mne.fgcolor ) vsel_patch = Rectangle( - (0, 0), 1, self.mne.n_channels, facecolor=self.mne.bgcolor, **sel_kwargs + (0, 0), 1, self.mne.n_channels - len(extra_chans), facecolor=self.mne.bgcolor, **sel_kwargs ) ax_vscroll.add_patch(vsel_patch) hsel_facecolor = np.average( @@ -688,7 +671,6 @@ def __init__(self, inst, figsize, ica=None, def _update_picks(self): - import numpy as np """Compute which channel indices to show.""" n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg'])) if self.mne.butterfly and self.mne.ch_selections is not None: @@ -718,7 +700,6 @@ def _draw_traces(self): # OSL ADDITION from mne import pick_types from mne.io.pick import channel_type - import numpy as np # clear scalebars if self.mne.scalebars_visible: @@ -928,9 +909,9 @@ def _draw_traces(self): plt.figtext(self.mne.bad_labels_xpos, self.mne.bad_labels_ypos[i], f'{i-1}: ' + self.mne.bad_labels_list[i - 2], color=self.mne.bad_label_colors[i - 2], fontweight='semibold') + self._update_vscroll() # takes care of the vsel_patch, because it's too big when there's extra chans def plot_topos(self, ica, ax_topo, picks): # OSL ADDITION FOR TOPOS - import numpy as np import mne from mne.viz.topomap import _plot_ica_topomap @@ -981,7 +962,6 @@ def plot_topos(self, ica, ax_topo, picks): # OSL ADDITION FOR TOPOS def _keypress(self, event): from mne.viz.utils import _events_off - import numpy as np """Handle keypress events.""" key = event.key @@ -1018,9 +998,8 @@ def _keypress(self, event): buttons.set_active(current_idx + direction) # normal case else: - ceiling = len(self.mne.ch_order) - ch_start = self.mne.picks[n_extra_chans] + direction * (n_channels - n_extra_chans) - # ch_start = np.clip(self.mne.ch_start, n_extra_chans, ceiling) + direction * (n_channels - n_extra_chans) + ceiling = len(self.mne.ch_order) - (n_channels - n_extra_chans) + ch_start = self.mne.ch_start + direction * (n_channels - n_extra_chans) self.mne.ch_start = np.clip(ch_start, n_extra_chans, ceiling) self._update_picks() self._update_vscroll() @@ -1149,13 +1128,18 @@ def _keypress(self, event): else: # check for close key / fullscreen toggle super()._keypress(event) - + def _update_vscroll(self): + """Update the vertical scrollbar (channel) selection indicator.""" + n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg'])) + self.mne.vsel_patch.set_xy((0, self.mne.ch_start - n_extra_chans)) + self.mne.vsel_patch.set_height(self.mne.n_channels - n_extra_chans) + self._update_yaxis_labels() + def _close(self, event): # OSL VERSION - SIMILAR TO OLD MNE VERSION TODO: Check if we need to adopt this """Handle close events (via keypress or window [x]).""" from matplotlib.pyplot import close from mne.utils import set_config - import numpy as np # write out bad epochs (after converting epoch numbers to indices) if self.mne.instance_type == "epochs":