diff --git a/tedana/decomposition/ica.py b/tedana/decomposition/ica.py index d687db6e0..40db7ce98 100644 --- a/tedana/decomposition/ica.py +++ b/tedana/decomposition/ica.py @@ -4,8 +4,9 @@ import warnings import numpy as np -from robustica import RobustICA +from robustica import RobustICA, abs_pearson_dist from scipy import stats +from sklearn import manifold from sklearn.decomposition import FastICA from tedana.config import ( @@ -69,7 +70,7 @@ def tedica( ica_method = ica_method.lower() if ica_method == "robustica": - mixing, fixed_seed = r_ica( + mixing, fixed_seed, c_labels, similarity_t_sne = r_ica( data, n_components=n_components, fixed_seed=fixed_seed, @@ -87,7 +88,7 @@ def tedica( else: raise ValueError("The selected ICA method is invalid!") - return mixing, fixed_seed + return mixing, fixed_seed, c_labels, similarity_t_sne def r_ica(data, n_components, fixed_seed, n_robust_runs, max_it): @@ -192,7 +193,23 @@ def r_ica(data, n_components, fixed_seed, n_robust_runs, max_it): f"decomposition." 
) - return mixing, fixed_seed + c_labels = robust_ica.clustering.labels_ + + perplexity = min(robust_ica.S_all.shape[1] - 1, 80) + + perplexity = perplexity - 1 if perplexity < 81 else 80 + t_sne = manifold.TSNE( + n_components=2, + perplexity=perplexity, + init="random", + n_iter=2500, + random_state=10, + ) + + p_dissimilarity = abs_pearson_dist(robust_ica.S_all) + similarity_t_sne = t_sne.fit_transform(p_dissimilarity) + + return mixing, fixed_seed, c_labels, similarity_t_sne def f_ica(data, n_components, fixed_seed, maxit, maxrestart): diff --git a/tedana/reporting/data/html/report_body_template.html b/tedana/reporting/data/html/report_body_template.html index 6cf812ba3..f0b87b625 100644 --- a/tedana/reporting/data/html/report_body_template.html +++ b/tedana/reporting/data/html/report_body_template.html @@ -39,9 +39,6 @@ .carpet-plots { float: left; - } - - .carpet-plots { margin-left: 5%; margin-right: 5%; margin-bottom: 100px; @@ -64,6 +61,13 @@ float: left; } + .tsne-plots { + float: left; + margin-left: 5%; + margin-right: 5%; + margin-bottom: 100px; + } + button { margin-right: 15px; width: auto; @@ -209,6 +213,7 @@

T2* and S0 model fit (RMSE). (Scaled between 2nd and 98th percentiles)

+$tsne

Info

def _scaled_hull_outline(points, scale=1.5):
    """Return line coordinates outlining the enlarged convex hull of ``points``.

    Parameters
    ----------
    points : (n_points, 2) numpy.ndarray
        2D coordinates of the points belonging to one cluster.
    scale : float, optional
        Factor by which the points are pushed away from the centroid of the
        hull vertices before the outline is built. Default is 1.5.

    Returns
    -------
    outline : tuple of list or None
        ``(xs, ys)`` lists holding one segment per hull edge, each segment
        followed by ``None`` as a separator. ``None`` is returned when no hull
        can be built for the points.
    """
    # Function-scope import keeps this block self-contained and avoids
    # re-running the import statement for every cluster.
    from scipy.spatial import ConvexHull, QhullError

    try:
        hull = ConvexHull(points)
    except QhullError:
        # Degenerate input (e.g. collinear or coincident points) has no 2D hull.
        return None

    centroid = np.mean(points[hull.vertices], axis=0)
    scaled_points = centroid + scale * (points - centroid)

    xs = []
    ys = []
    for start, end in hull.simplices:
        xs.extend([scaled_points[start, 0], scaled_points[end, 0], None])
        ys.extend([scaled_points[start, 1], scaled_points[end, 1], None])
    return xs, ys


def _create_clustering_tsne_plt(cluster_labels, similarity_t_sne):
    """Plot the clustering results of robustica using Bokeh.

    Every run is drawn at its projected 2D coordinates. Runs with a
    non-negative label are drawn as open circles, and each cluster with more
    than two runs is additionally outlined by a dashed, enlarged convex hull.
    Runs labeled -1 are drawn as red crosses.

    Parameters
    ----------
    cluster_labels : (n_pca_components x n_robust_runs,) : numpy.ndarray
        A one dimensional array that has the cluster label of each run.
        A label of -1 marks a run that is plotted as unclustered.
    similarity_t_sne : (n_pca_components x n_robust_runs,2) : numpy.ndarray
        An array containing the coordinates of projected data.

    Returns
    -------
    p : bokeh figure
        The figure holding the scatter renderers, the hull outlines, a hover
        tool restricted to the point renderers and a clickable legend.
    """
    # Coerce to arrays so the label comparisons below are element-wise even
    # when a plain list is passed (a list compared to an int is a scalar False).
    cluster_labels = np.asarray(cluster_labels)
    similarity_t_sne = np.asarray(similarity_t_sne)

    title = "2D projection of clustered ICA runs using TSNE"
    marker_size = 8
    alpha = 0.8
    line_width = 2
    hull_scale = 1.5

    # First create the figure without the hover tool
    p = plotting.figure(
        title=title,
        width=800,
        height=600,
        tools=["pan", "box_zoom", "wheel_zoom", "reset", "save"],  # No hover tool here
    )

    point_renderers = []  # Renderers the hover tool is attached to

    # Plot regular clusters. np.unique yields only labels that actually occur,
    # in ascending order, and is safe on empty input.
    for cluster_id in np.unique(cluster_labels[cluster_labels >= 0]):
        cluster_points = similarity_t_sne[cluster_labels == cluster_id]

        # Add scatter plot for cluster points with hover info
        circle_renderer = p.circle(
            x="x",
            y="y",
            source=models.ColumnDataSource(
                {
                    "x": cluster_points[:, 0],
                    "y": cluster_points[:, 1],
                    "cluster": [f"Cluster {cluster_id}"] * len(cluster_points),
                }
            ),
            size=marker_size,
            alpha=alpha,
            line_color="black",
            fill_color=None,
            line_width=line_width,
            legend_label="Clustered runs",
            name="points",
        )
        point_renderers.append(circle_renderer)

        # A 2D hull needs at least three points
        if cluster_points.shape[0] <= 2:
            continue

        outline = _scaled_hull_outline(cluster_points, scale=hull_scale)
        if outline is None:
            # No hull could be built; keep the points and skip the boundary
            # instead of aborting the whole figure.
            continue

        xs, ys = outline
        # Add line without hover tooltips
        p.line(
            x=xs,
            y=ys,
            line_color="blue",
            line_dash="dashed",
            line_width=line_width,
            legend_label="Cluster's boundary",
        )

    # Plot unclustered runs if any exist
    noise_mask = cluster_labels == -1
    if np.any(noise_mask):
        noise_points = similarity_t_sne[noise_mask]

        # Add noise points with hover tooltips
        x_renderer = p.x(
            x="x",
            y="y",
            size=marker_size * 2,
            alpha=0.6,
            color="red",
            legend_label="Unclustered runs",
            source=models.ColumnDataSource(
                {
                    "x": noise_points[:, 0],
                    "y": noise_points[:, 1],
                    "cluster": ["Unclustered"] * len(noise_points),
                }
            ),
        )
        point_renderers.append(x_renderer)

    # Add hover tool after creating all renderers, specifically for points
    hover_tool = models.HoverTool(
        tooltips=[("Cluster", "@cluster")],
        renderers=point_renderers,  # Only apply to stored point renderers
    )
    p.add_tools(hover_tool)

    # Configure legend
    p.legend.click_policy = "hide"
    p.legend.location = "top_right"

    return p
@@ -133,6 +135,8 @@ def _update_template_bokeh(bokeh_id, info_table, about, prefix, references, boke Javascript created by bokeh.embed.components buttons : str HTML div created by _generate_buttons() + tsne : str + HTML div created by _create_clustering_tsne_plt() Returns ------- @@ -181,6 +185,7 @@ def _update_template_bokeh(bokeh_id, info_table, about, prefix, references, boke references=references, javascript=bokeh_js, buttons=buttons, + tsne=tsne, ) return body @@ -231,7 +236,7 @@ def _generate_info_table(info_dict): return info_html -def generate_report(io_generator: OutputGenerator) -> None: +def generate_report(io_generator: OutputGenerator, cluster_labels, similarity_t_sne) -> None: """Generate an HTML report. Parameters @@ -320,6 +325,15 @@ def get_elbow_val(elbow_prefix): ) varexp_pie_plot = df._create_varexp_pie_plt(comptable_cds) + # Create clustering plot + if cluster_labels is not None: + clustering_tsne_plot = df._create_clustering_tsne_plt(cluster_labels, similarity_t_sne) + tsne_script, tsne_div = embed.components(clustering_tsne_plot) + tsne_html = f"{tsne_script}" + tsne_html += ( + f"

Robust ICA component clustering

{tsne_div}
" + ) + # link all dynamic figures figs = [kappa_rho_plot, kappa_sorted_plot, rho_sorted_plot, varexp_pie_plot] @@ -371,6 +385,7 @@ def get_elbow_val(elbow_prefix): prefix=io_generator.prefix, bokeh_js=kr_script, buttons=buttons_html, + tsne=tsne_html, ) html = _save_as_html(body) with open(opj(io_generator.out_dir, f"{io_generator.prefix}tedana_report.html"), "wb") as f: diff --git a/tedana/workflows/tedana.py b/tedana/workflows/tedana.py index 1f54361e6..2f6efe52e 100644 --- a/tedana/workflows/tedana.py +++ b/tedana/workflows/tedana.py @@ -789,7 +789,7 @@ def tedana_workflow( n_restarts = 0 seed = fixed_seed while keep_restarting: - mixing, seed = decomposition.tedica( + mixing, seed, cluster_labels, similarity_t_sne = decomposition.tedica( data_reduced, n_components, seed, @@ -1070,7 +1070,7 @@ def tedana_workflow( ) LGR.info("Generating dynamic report") - reporting.generate_report(io_generator) + reporting.generate_report(io_generator, cluster_labels, similarity_t_sne) LGR.info("Workflow completed") utils.teardown_loggers()