diff --git a/petab_select/plot.py b/petab_select/plot.py index 31ae4e0b..0215ed13 100644 --- a/petab_select/plot.py +++ b/petab_select/plot.py @@ -1,6 +1,8 @@ """Visualization routines for model selection with pyPESTO.""" from typing import Any, Dict, List, Optional, Tuple, Union +import matplotlib.cm +import matplotlib.colors import matplotlib.pyplot as plt import networkx as nx import numpy as np @@ -432,7 +434,10 @@ def graph_iteration_layers( relative: bool = True, ax: plt.Axes = None, draw_networkx_kwargs: Optional[Dict[str, Any]] = None, - colors: Dict[str, str] = None, + # colors: Dict[str, str] = None, + parameter_labels: Dict[str, str] = None, + augment_labels: bool = True, + use_tex: bool = True, ) -> plt.Axes: """Graph the models of each iteration of model selection. @@ -453,49 +458,42 @@ def graph_iteration_layers( The axis to use for plotting. draw_networkx_kwargs: Passed to the `networkx.draw_networkx` call. + parameter_labels: + A dictionary of parameter labels, where keys are parameter IDs, and + values are parameter labels, for plotting. Defaults to parameter IDs. + augment_labels: + If ``True``, provided labels will have estimated parameters and + relative criterion values added to them, for plotting. Returns ------- matplotlib.pyplot.Axes The plot axis. 
""" + if use_tex: + rcParams0 = dict(plt.rcParams) + rcParams = { + 'text.usetex': True, + #'text.latex': { + # 'preamble': r'\usepackage{color,xcolor}', + # }, + "pgf.rcfonts": False, + "pgf.preamble": r'\usepackage{color}', + } + plt.rcParams.update(rcParams) + # for rcParam, rcParamValues in rcParams.items(): + # plt.rc(rcParam, **rcParamValues) + if ax is None: _, ax = plt.subplots(figsize=(20, 10)) - if labels is None: - labels = { - model.get_hash(): model.model_id - + ( - f'\n{model.get_criterion(criterion):.2f}' - if criterion is not None - else '' - ) - for model in models - } - labels[VIRTUAL_INITIAL_MODEL] = labels.get( - VIRTUAL_INITIAL_MODEL, "Virtual\nInitial\nModel" - ) - - missing_labels = { - model.get_hash(): model.model_id - for model in models - if model.get_hash() not in labels - } - missing_labels2 = { - model.predecessor_model_hash: model.predecessor_model_hash - for model in models - if model.predecessor_model_hash not in labels - } - labels.update(missing_labels2) - labels.update(missing_labels) - # for label in missing_labels: - # labels[label] = label + model_hashes = {model.get_hash(): model for model in models} default_draw_networkx_kwargs = { - 'node_color': NORMAL_NODE_COLOR, + #'node_color': NORMAL_NODE_COLOR, 'arrowstyle': '-|>', 'node_shape': 's', - 'node_size': 2500, + 'node_size': 250, } if draw_networkx_kwargs is None: draw_networkx_kwargs = default_draw_networkx_kwargs @@ -506,6 +504,115 @@ def graph_iteration_layers( ancestry_as_set = {k: set([v]) for k, v in ancestry.items()} ordering = [list(hashes) for hashes in toposort(ancestry_as_set)] + model_estimated_parameters = { + model.get_hash(): set(model.estimated_parameters) for model in models + } + model_criterion_values = None + if criterion is not None: + model_criterion_values = { + model.get_hash(): model.get_criterion(criterion) + for model in models + } + + min_criterion_value = min(model_criterion_values.values()) + model_criterion_values = { + k: v - 
min_criterion_value + for k, v in model_criterion_values.items() + } + + model_parameter_diffs = { + model.get_hash(): ( + (set(), set()) + if model.predecessor_model_hash not in model_estimated_parameters + else ( + model_estimated_parameters[model.get_hash()].difference( + model_estimated_parameters[model.predecessor_model_hash] + ), + model_estimated_parameters[ + model.predecessor_model_hash + ].difference(model_estimated_parameters[model.get_hash()]), + ) + ) + for model in models + } + + if labels is None: + labels = {model.get_hash(): model.model_id for model in models} + if augment_labels: + + class Identidict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __getitem__(self, key): + return key + + if parameter_labels is None: + parameter_labels = Identidict() + + model_added_parameters = { + model_hash: ';'.join( + [ + parameter_labels[parameter_id] + for parameter_id in sorted( + model_parameter_diffs[model_hash][0] + ) + ] + ) + for model_hash in model_estimated_parameters + } + model_removed_parameters = { + model_hash: ';'.join( + [ + parameter_labels[parameter_id] + for parameter_id in sorted( + model_parameter_diffs[model_hash][1] + ) + ] + ) + for model_hash in model_estimated_parameters + } + + labels = { + model_hash: ( + label0 + # + ( + # f'\n{model_criterion_values[model_hash]:.2f}' + # if model_hash in model_criterion_values + # else '' + # ) + + ( + ( + # f'\n' + + r'\textcolor{green}{hi}' + # + ('\\textcolor{green}{' if use_tex else '') + + # f'{model_added_parameters[model_hash]}' + # + ('}' if use_tex else '') + if model_added_parameters.get(model_hash, '') + else '' + ) + + ( + # f'\n' + + r'\textcolor{red}{hi}' + # + ('\\textcolor{red}{' if use_tex else '') + + # f'{model_removed_parameters[model_hash]}' + # + ('}' if use_tex else '') + if model_removed_parameters.get(model_hash, '') + else '' + ) + ) + ) + for model_hash, label0 in labels.items() + } + + norm = matplotlib.colors.Normalize( + 
vmin=min(model_criterion_values.values()), + vmax=max(model_criterion_values.values()), + ) + cmap = matplotlib.cm.get_cmap('cool') + scalar_mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) + ax.get_figure().colorbar(scalar_mappable, ax=ax) + G = nx.DiGraph( [ (predecessor, successor) @@ -525,30 +632,76 @@ def graph_iteration_layers( ] for layer in ordering ] + + # Labels + # if labels is None: + # labels = { + # model.get_hash(): model.model_id + # + ( + # f'\n{model.get_criterion(criterion):.2f}' + # if criterion is not None + # else '' + # ) + # for model in models + # } + labels[VIRTUAL_INITIAL_MODEL] = labels.get( + VIRTUAL_INITIAL_MODEL, "Virtual\nInitial\nModel" + ) + + missing_labels = { + model.get_hash(): model.model_id + for model in models + if model.get_hash() not in labels + } + missing_labels2 = { + model.predecessor_model_hash: model.predecessor_model_hash + for model in models + if model.predecessor_model_hash not in labels + } + labels.update(missing_labels2) + labels.update(missing_labels) + # for label in missing_labels: + # labels[label] = label + pos = { labels[model_hash]: (X[i], Y[i][j]) for i, layer in enumerate(ordering) for j, model_hash in enumerate(layer) } nx.relabel_nodes(G, mapping=labels, copy=False) - if colors is not None: - if label_diff := set(colors).difference(list(G)): - raise ValueError( - "Colors were provided for the following model labels, but " - "these are not in the graph: {label_diff}" - ) - node_colors = [ - colors.get( - model_label, - draw_networkx_kwargs.get( - 'node_color', default_draw_networkx_kwargs['node_color'] - ), - ) - for model_label in list(G) - ] - draw_networkx_kwargs.update({'node_color': node_colors}) - nx.draw_networkx(G, pos, ax=ax, **draw_networkx_kwargs) + G_hashes = [ + one([k for k, v in labels.items() if v == label]) for label in G.nodes + ] + node_colors = [ + ( + scalar_mappable.to_rgba(model_criterion_values[model_hash]) + if model_hash in model_criterion_values + else 
NORMAL_NODE_COLOR + ) + for model_hash in G_hashes + ] + + # if colors is not None: + # if label_diff := set(colors).difference(list(G)): + # raise ValueError( + # "Colors were provided for the following model labels, but " + # "these are not in the graph: {label_diff}" + # ) + + # node_colors = [ + # colors.get( + # model_label, + # draw_networkx_kwargs.get( + # 'node_color', default_draw_networkx_kwargs['node_color'] + # ), + # ) + # for model_label in list(G) + # ] + # draw_networkx_kwargs.update({'node_color': node_colors}) + nx.draw_networkx( + G, pos, ax=ax, node_color=node_colors, **draw_networkx_kwargs + ) # Add `n=...` labels N = [len(y) for y in Y] @@ -559,78 +712,80 @@ def graph_iteration_layers( fontsize=draw_networkx_kwargs.get('font_size', 20), ) - # Get selected parameter IDs - # TODO move this logic elsewhere - selected_hashes = set(ancestry.values()) - selected_models = {} - for model in models: - if model.get_hash() in selected_hashes: - selected_models[model.get_hash()] = model - - selected_parameters = { - model_hash: sorted(model.estimated_parameters) - for model_hash, model in selected_models.items() - } - - selected_order = [ - [model_hash for model_hash in layer if model_hash in selected_models] - for layer in ordering - ] - selected_order = [ - None if not model_hash else one(model_hash) - for model_hash in selected_order - ] - - selected_parameter_ids = [] - estimated0 = None - model_hash = None - for model_hash in selected_order: - if model_hash is None: - selected_parameter_ids.append('') - continue - if estimated0 is not None: - new_parameter_ids = list( - set(selected_parameters[model_hash]).symmetric_difference( - estimated0 - ) - ) - new_parameter_names = [] - for new_parameter_id in new_parameter_ids: - # Default to parameter ID, use parameter name if available - new_parameter_name = new_parameter_id - if ( - PARAMETER_NAME - in selected_models[ - model_hash - ].petab_problem.parameter_df.columns - ): - petab_parameter_name = 
selected_models[ - model_hash - ].petab_problem.parameter_df.loc[ - new_parameter_id, PARAMETER_NAME - ] - if not pd.isna(petab_parameter_name): - new_parameter_name = petab_parameter_name - new_parameter_names.append(new_parameter_name) - new_parameter_names = [ - new_parameter_name.replace('\\\\rightarrow ', '->') - for new_parameter_name in new_parameter_names - ] - selected_parameter_ids.append(sorted(new_parameter_names)) - else: - selected_parameter_ids.append(['']) - estimated0 = selected_parameters[model_hash] - - # Add labels for selected parameters - for x, label in zip(X, selected_parameter_ids): - ax.annotate( - "\n".join(label), - xy=(x, 1.15), - fontsize=draw_networkx_kwargs.get('font_size', 20), - ) + ## Get selected parameter IDs + ## TODO move this logic elsewhere + # selected_hashes = set(ancestry.values()) + # selected_models = {} + # for model in models: + # if model.get_hash() in selected_hashes: + # selected_models[model.get_hash()] = model + + # selected_parameters = { + # model_hash: sorted(model.estimated_parameters) + # for model_hash, model in selected_models.items() + # } + + # selected_order = [ + # [model_hash for model_hash in layer if model_hash in selected_models] + # for layer in ordering + # ] + # selected_order = [ + # None if not model_hash else one(model_hash) + # for model_hash in selected_order + # ] + + # selected_parameter_ids = [] + # estimated0 = None + # model_hash = None + # for model_hash in selected_order: + # if model_hash is None: + # selected_parameter_ids.append('') + # continue + # if estimated0 is not None: + # new_parameter_ids = list( + # set(selected_parameters[model_hash]).symmetric_difference( + # estimated0 + # ) + # ) + # new_parameter_names = [] + # for new_parameter_id in new_parameter_ids: + # # Default to parameter ID, use parameter name if available + # new_parameter_name = new_parameter_id + # if ( + # PARAMETER_NAME + # in selected_models[ + # model_hash + # ].petab_problem.parameter_df.columns + # 
): + # petab_parameter_name = selected_models[ + # model_hash + # ].petab_problem.parameter_df.loc[ + # new_parameter_id, PARAMETER_NAME + # ] + # if not pd.isna(petab_parameter_name): + # new_parameter_name = petab_parameter_name + # new_parameter_names.append(new_parameter_name) + # new_parameter_names = [ + # new_parameter_name.replace('\\\\rightarrow ', '->') + # for new_parameter_name in new_parameter_names + # ] + # selected_parameter_ids.append(sorted(new_parameter_names)) + # else: + # selected_parameter_ids.append(['']) + # estimated0 = selected_parameters[model_hash] + + ## Add labels for selected parameters + # for x, label in zip(X, selected_parameter_ids): + # ax.annotate( + # "\n".join(label), + # xy=(x, 1.15), + # fontsize=draw_networkx_kwargs.get('font_size', 20), + # ) # Set margins for the axes so that nodes aren't clipped ax.margins(0.15) ax.axis("off") + # FIXME + plt.rcParams.update(rcParams0) return ax