diff --git a/compas_python_utils/detailed_evolution_plotter/plot_detailed_evolution.py b/compas_python_utils/detailed_evolution_plotter/plot_detailed_evolution.py index f7479b796..ab520c7f0 100644 --- a/compas_python_utils/detailed_evolution_plotter/plot_detailed_evolution.py +++ b/compas_python_utils/detailed_evolution_plotter/plot_detailed_evolution.py @@ -3,7 +3,6 @@ # Plot the detailed evolution of a COMPAS run # # # ################################################################### - import os import shutil import numpy as np @@ -13,11 +12,12 @@ import argparse import tempfile from pathlib import Path +from .plot_to_json import get_plot_data, get_events_data, NumpyEncoder +import json IMG_DIR = Path(__file__).parent / "van_den_heuvel_figures" - def main(): parser = argparse.ArgumentParser(description='Plot detailed evolution of a COMPAS binary') default_data_path = "./COMPAS_Output/Detailed_Output/BSE_Detailed_Output_0.h5" @@ -26,10 +26,10 @@ def main(): parser.add_argument('--outdir', type=str, default='.', help='Path to the directory to save the figures') parser.add_argument('--dont-show', action='store_false', help='Dont show the plots') args = parser.parse_args() - run_main_plotter(args.data_path, outdir=args.outdir, show=args.dont_show) - + run_main_plotter(args.data_path, outdir=args.outdir, show=args.dont_show, as_json=True) + -def run_main_plotter(data_path, outdir='.', show=True): +def run_main_plotter(data_path, outdir='.', show=True, as_json=False): ### Collect the raw data and mask for just the end-of-timesteps events RawData = h5.File(data_path, 'r') @@ -40,13 +40,17 @@ def run_main_plotter(data_path, outdir='.', show=True): Data.create_dataset(key, data=RawData[key][()][maskRecordType4]) print(np.unique(Data['Record_Type'][()])) - ### Collect the important events in the detailed evolution + ## Collect the important events in the detailed evolution events = allEvents(Data).allEvents # Calculate the events here, for use in plot sizing parameters printEvolutionaryHistory(events=events) ### Produce the two plots - makeDetailedPlots(Data, events, outdir=outdir) - plotVanDenHeuvel(events=events, outdir=outdir) + detailed = makeDetailedPlots(Data, events, outdir=outdir, as_json=as_json) + vdh = plotVanDenHeuvel(events=events, outdir=outdir, as_json=as_json) + + if as_json: + return json.dumps({**detailed, **vdh}, cls=NumpyEncoder) + if show: plt.show() @@ -71,7 +75,7 @@ def set_font_params(): ####### Functions to organize and call the plotting functions -def makeDetailedPlots(Data=None, events=None, outdir='.', show=True): +def makeDetailedPlots(Data=None, events=None, outdir='.', show=True, as_json=False): listOfPlots = [plotMassAttributes, plotLengthAttributes, plotStellarTypeAttributes, plotHertzsprungRussell] events = [event for event in events if event.eventClass != 'Stype'] # want to ignore simple stellar type changes @@ -128,6 +132,11 @@ def makeDetailedPlots(Data=None, events=None, outdir='.', show=True): fig.suptitle('Detailed evolution for seed = {}'.format(Data['SEED'][()][0]), fontsize=18) fig.tight_layout(h_pad=1, w_pad=1, rect=(0.08, 0.08, .98, .98), pad=0.) # (left, bottom, right, top) + if as_json: + fig_json = get_plot_data(fig) + plt.close('all') + return fig_json + safe_save_figure(fig, f'{outdir}/detailedEvolutionPlot.png', bbox_inches='tight', pad_inches=0, format='png') @@ -148,6 +157,8 @@ def plotMassAttributes(ax=None, Data=None, mask=None, **kwargs): ax.set_ylabel(r'Mass $/ \; M_{\odot}$') + ax.tag = "mass_plot" + return ax.get_legend_handles_labels() @@ -163,6 +174,8 @@ def plotLengthAttributes(ax=None, Data=None, mask=None, **kwargs): ax.set_ylabel(r'Radius $/ \; R_{\odot}$') ax.set_yscale('log') + ax.tag = "length_plot" + return ax.get_legend_handles_labels() @@ -296,12 +309,16 @@ def get_L(t): # assumes K ax.legend(framealpha=1, prop={'size': 8}) ax.grid(linestyle=':', c='gray') - return ax.get_legend_handles_labels() + ax.tag = "hr_plot" + return ax.get_legend_handles_labels() -def plotVanDenHeuvel(events=None, outdir='.'): +def plotVanDenHeuvel(events=None, outdir='.', as_json=False): # Only want events with an associated image events = [event for event in events if (event.eventImage is not None)] + if as_json: + return get_events_data(events) + num_events = len(events) fig, axs = plt.subplots(num_events, 1) if num_events == 1: @@ -426,7 +443,7 @@ def __init__(self, Data, index, eventClass, stellarTypeMap, **kwargs): self.eventImage = None self.endState = None # sets the endstate - only relevant if eventClass=='End' - self.eventString = self.getEventDetails(**kwargs) + self.eventString, self.image_num, self.rotate_image = self.getEventDetails(**kwargs) def getEventDetails(self, **kwargs): """ @@ -574,7 +591,7 @@ def getEventDetails(self, **kwargs): if image_num != None: self.eventImage = self.getEventImage(image_num, rotate_image) - return eventString + return eventString, image_num, rotate_image def getEventImage(self, image_num, rotate_image): """ diff --git a/compas_python_utils/detailed_evolution_plotter/plot_to_json.py b/compas_python_utils/detailed_evolution_plotter/plot_to_json.py new file mode 100644 index 000000000..0fae04799 --- /dev/null +++ b/compas_python_utils/detailed_evolution_plotter/plot_to_json.py @@ -0,0 +1,263 @@ + +import json +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.lines import _get_dash_pattern, _scale_dashes +from matplotlib.colors import to_hex, to_rgba + +class NumpyEncoder(json.JSONEncoder): + """ Special json encoder for numpy types """ + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + +def _artist_is_line(artist): + return isinstance(artist, plt.Line2D) and not artist.get_label().startswith("_child") + +def _artist_is_ref_line(artist): + return isinstance(artist, plt.Line2D) and artist.get_label().startswith("_child") + +def _artist_is_text(artist): + return isinstance(artist, plt.Text) and artist.get_text() + +def get_artist_colour(artist): + """Get the colour of the artist in hex format + + Parameters + ---------- + artist : pyplot.Line2D or pyplot.Text + Pyplot artist for which to obtain the colour + + Returns + ------- + str + String with the colour in hex format + """ + return to_hex(to_rgba(artist.get_color(), artist.get_alpha()), keep_alpha=True) + +def get_ref_line_data(ref_line, label): + """Obtain the data and metadata needed to render reference lines on the GWLandscape service + + Parameters + ---------- + ref_line : pyplot.Line2D + Reference line object to be rendered + label : str + Name of the reference line + + Returns + ------- + dict + Dictionary with the data and metadata needed to render the reference line + """ + ref_line_meta = get_line_meta(ref_line, "y", label) + + xs, ys = ref_line.get_xdata(), ref_line.get_ydata() + + if xs[0] == xs[-1] and len(set(xs)) == 1: + ref_line_meta["type"] = "vline" + elif ys[0] == ys[-1] and len(set(ys)) == 1: + ref_line_meta["type"] = "hline" + else: + ref_line_meta["type"] = "ref" + + ref_line_meta["points"] = [ + {"x": xs[0], "y": ys[0]}, + {"x": xs[-1], "y": ys[-1]}, + ] + return ref_line_meta + +def get_text_data(text): + """Obtain the data needed to render text elements on the GWLandscape service + + Parameters + ---------- + text : pyplot.Text + Text object to be rendered + + Returns + ------- + dict + Dictionary containing the data and metadata necessary to render the text elements + """ + return { + "label": text.get_text(), + "x": text.get_position()[0], + "y": text.get_position()[1], + "colour": get_artist_colour(text) + } + +def get_line_dashes(line): + """Obtain the dash pattern of a line. This uses a private attribute of the Line2D class or + a private function from pyplot + + Parameters + ---------- + line : _type_pyplot.Line2D + Line object for which to obtain the dash pattern + + Returns + ------- + str or None + String of numbers separated by spaces representing the lengths of dashes and spaces in the pattern + """ + if hasattr(line, "_dash_pattern"): + _, dashes = line._dash_pattern + else: + _, dashes = _scale_dashes(*_get_dash_pattern(line.get_linestyle()), line.get_linewidth()) + + return " ".join(map(str, dashes)) if dashes else None + +def get_line_meta(line, y_key, label=None): + """Get the metadata for a line + + Parameters + ---------- + line : pyplot.Line2D + A line from a line plot + y_key : str + The key for the data, which will be used to identify it on the frontend + label : str, optional + The name of the line, by default None. If None, the label will be obtained from the Line2D object + + Returns + ------- + dict + Dictionary containing the metadata necessary to render the line properly on the frontend + """ + return { + "colour": get_artist_colour(line), + "dashes": get_line_dashes(line), + "width": line.get_linewidth(), + "xKey": "x", + "yKey": y_key, + "label": line.get_label() if label is None else label, + "type": "data" + } + +def get_line_groups(lines): + """Takes a list of Line2D objects and organises them into groups based on whether or not they have the same + x data. Each group contains the x data for the group, a list of y axis data for each line in the group, and + the metadata for each line. + + + Parameters + ---------- + lines : list + List of pyplot Line2D objects + + Returns + ------- + list + List of groups containing the x data of the group and the y data and metadata for all lines in the group + """ + groups = [] + for i, line in enumerate(lines): + x_data = line.get_xdata() + y_data = line.get_ydata() + meta = get_line_meta(line, f"y{i}") + for group in groups: + if np.array_equal(group["x_data"], x_data): + group["y_data"].append(y_data) + group["meta"].append(meta) + break + else: + groups.append({"x_data": x_data, "y_data": [y_data], "meta": [meta]}) + return groups + +def get_plot_data(fig): + """Takes a pyplot Figure instance and outputs JSON data to render the plots on the GWLandscape service + + Parameters + ---------- + fig : pyplot.Figure + Pyplot Figure to replicate + + Returns + ------- + dict + Dictionary containing necessary data to render the plots on a webpage + """ + json_data = { + "plots": {} + } + + for ax in fig.get_axes(): + if not hasattr(ax, "tag"): + continue + + if ax.xaxis_inverted(): + ax.invert_xaxis() + + artists = ax.get_children() + ref_lines = [ + get_ref_line_data(ref_line, f"refLine{i}") + for i, ref_line in enumerate(filter(_artist_is_ref_line, artists)) + ] + texts = [get_text_data(text) for text in filter(_artist_is_text, artists)] + + line_groups = get_line_groups(filter(_artist_is_line, artists)) + + groups = [] + for line_group in line_groups: + group = {"data": [], "meta": line_group["meta"]} + for i, x in enumerate(line_group["x_data"]): + row = {line["yKey"]: line_group["y_data"][j][i] for j, line in enumerate(group["meta"])} + row[group["meta"][0]["xKey"]] = x + group["data"].append(row) + groups.append(group) + + json_data["plots"][ax.tag] = { + "meta": { + "xAxis": { + "label": ax.get_xlabel(), + "ticks": ax.get_xticks(), + "limits": ax.get_xlim(), + "scale": ax.get_xscale() + }, + "yAxis": { + "label": ax.get_ylabel(), + "ticks": ax.get_yticks(), + "limits": ax.get_ylim(), + "scale": ax.get_yscale() + }, + }, + "groups": groups, + "refLines": ref_lines, + "texts": texts + } + + return json_data + +def get_events_data(events): + """Uses a list of Events to generate a JSON structure for rendering Van Den Heuvel plots on the GWLandscape service + + Parameters + ---------- + events : list + List of Events + + Returns + ------- + dict + Dictionary containing data necessary to render VDH plot on a webpage + """ + return { + "events": [ + { + "eventChar": chr(ord('@') + 1 + i), + "time": event.time, + "a": [event.aprev, event.a], + "m1": [event.m1prev, event.m1], + "m2": [event.m2prev, event.m2], + "eventString": event.eventString, + "imageNum": event.image_num, + "flipImage": event.rotate_image, + } for i, event in enumerate(events) + ] + } \ No newline at end of file