Skip to content

Commit

Permalink
Initial work on plotting backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom Reichardt committed Jan 16, 2025
1 parent a54570e commit e1052db
Show file tree
Hide file tree
Showing 2 changed files with 293 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Plot the detailed evolution of a COMPAS run #
# #
###################################################################

import os
import shutil
import numpy as np
Expand All @@ -13,11 +12,12 @@
import argparse
import tempfile
from pathlib import Path
from .plot_to_json import get_plot_data, get_events_data, NumpyEncoder
import json

IMG_DIR = Path(__file__).parent / "van_den_heuvel_figures"



def main():
parser = argparse.ArgumentParser(description='Plot detailed evolution of a COMPAS binary')
default_data_path = "./COMPAS_Output/Detailed_Output/BSE_Detailed_Output_0.h5"
Expand All @@ -26,10 +26,10 @@ def main():
parser.add_argument('--outdir', type=str, default='.', help='Path to the directory to save the figures')
parser.add_argument('--dont-show', action='store_false', help='Dont show the plots')
args = parser.parse_args()
run_main_plotter(args.data_path, outdir=args.outdir, show=args.dont_show)

run_main_plotter(args.data_path, outdir=args.outdir, show=args.dont_show, as_json=True)

def run_main_plotter(data_path, outdir='.', show=True):
def run_main_plotter(data_path, outdir='.', show=True, as_json=False):

### Collect the raw data and mask for just the end-of-timesteps events
RawData = h5.File(data_path, 'r')
Expand All @@ -40,13 +40,17 @@ def run_main_plotter(data_path, outdir='.', show=True):
Data.create_dataset(key, data=RawData[key][()][maskRecordType4])
print(np.unique(Data['Record_Type'][()]))

### Collect the important events in the detailed evolution
## Collect the important events in the detailed evolution
events = allEvents(Data).allEvents # Calculate the events here, for use in plot sizing parameters
printEvolutionaryHistory(events=events)

### Produce the two plots
makeDetailedPlots(Data, events, outdir=outdir)
plotVanDenHeuvel(events=events, outdir=outdir)
detailed = makeDetailedPlots(Data, events, outdir=outdir, as_json=as_json)
vdh = plotVanDenHeuvel(events=events, outdir=outdir, as_json=as_json)

if as_json:
return json.dumps({**detailed, **vdh}, cls=NumpyEncoder)

if show:
plt.show()

Expand All @@ -71,7 +75,7 @@ def set_font_params():

####### Functions to organize and call the plotting functions

def makeDetailedPlots(Data=None, events=None, outdir='.', show=True):
def makeDetailedPlots(Data=None, events=None, outdir='.', show=True, as_json=False):
listOfPlots = [plotMassAttributes, plotLengthAttributes, plotStellarTypeAttributes, plotHertzsprungRussell]

events = [event for event in events if event.eventClass != 'Stype'] # want to ignore simple stellar type changes
Expand Down Expand Up @@ -128,6 +132,11 @@ def makeDetailedPlots(Data=None, events=None, outdir='.', show=True):
fig.suptitle('Detailed evolution for seed = {}'.format(Data['SEED'][()][0]), fontsize=18)
fig.tight_layout(h_pad=1, w_pad=1, rect=(0.08, 0.08, .98, .98), pad=0.) # (left, bottom, right, top)

if as_json:
fig_json = get_plot_data(fig)
plt.close('all')
return fig_json

safe_save_figure(fig, f'{outdir}/detailedEvolutionPlot.png', bbox_inches='tight', pad_inches=0, format='png')


Expand All @@ -148,6 +157,8 @@ def plotMassAttributes(ax=None, Data=None, mask=None, **kwargs):

ax.set_ylabel(r'Mass $/ \; M_{\odot}$')

ax.tag = "mass_plot"

return ax.get_legend_handles_labels()


Expand All @@ -163,6 +174,8 @@ def plotLengthAttributes(ax=None, Data=None, mask=None, **kwargs):
ax.set_ylabel(r'Radius $/ \; R_{\odot}$')
ax.set_yscale('log')

ax.tag = "length_plot"

return ax.get_legend_handles_labels()


Expand Down Expand Up @@ -296,12 +309,16 @@ def get_L(t): # assumes K
ax.legend(framealpha=1, prop={'size': 8})
ax.grid(linestyle=':', c='gray')

return ax.get_legend_handles_labels()
ax.tag = "hr_plot"

return ax.get_legend_handles_labels()

def plotVanDenHeuvel(events=None, outdir='.'):
def plotVanDenHeuvel(events=None, outdir='.', as_json=False):
# Only want events with an associated image
events = [event for event in events if (event.eventImage is not None)]
if as_json:
return get_events_data(events)

num_events = len(events)
fig, axs = plt.subplots(num_events, 1)
if num_events == 1:
Expand Down Expand Up @@ -426,7 +443,7 @@ def __init__(self, Data, index, eventClass, stellarTypeMap, **kwargs):

self.eventImage = None
self.endState = None # sets the endstate - only relevant if eventClass=='End'
self.eventString = self.getEventDetails(**kwargs)
self.eventString, self.image_num, self.rotate_image = self.getEventDetails(**kwargs)

def getEventDetails(self, **kwargs):
"""
Expand Down Expand Up @@ -574,7 +591,7 @@ def getEventDetails(self, **kwargs):
if image_num != None:
self.eventImage = self.getEventImage(image_num, rotate_image)

return eventString
return eventString, image_num, rotate_image

def getEventImage(self, image_num, rotate_image):
"""
Expand Down
263 changes: 263 additions & 0 deletions compas_python_utils/detailed_evolution_plotter/plot_to_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@

import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import _get_dash_pattern, _scale_dashes
from matplotlib.colors import to_hex, to_rgba

class NumpyEncoder(json.JSONEncoder):
""" Special json encoder for numpy types """
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

def _artist_is_line(artist):
return isinstance(artist, plt.Line2D) and not artist.get_label().startswith("_child")

def _artist_is_ref_line(artist):
return isinstance(artist, plt.Line2D) and artist.get_label().startswith("_child")

def _artist_is_text(artist):
return isinstance(artist, plt.Text) and artist.get_text()

def get_artist_colour(artist):
"""Get the colour of the artist in hex format
Parameters
----------
artist : pyplot.Line2D or pyplot.Text
Pyplot artist for which to obtain the colour
Returns
-------
str
String with the colour in hex format
"""
return to_hex(to_rgba(artist.get_color(), artist.get_alpha()), keep_alpha=True)

def get_ref_line_data(ref_line, label):
"""Obtain the data and metadata needed to render reference lines on the GWLandscape service
Parameters
----------
ref_line : pyplot.Line2D
Reference line object to be rendered
label : str
Name of the reference line
Returns
-------
dict
Dictionary with the data and metadata needed to render the reference line
"""
ref_line_meta = get_line_meta(ref_line, "y", label)

xs, ys = ref_line.get_xdata(), ref_line.get_ydata()

if xs[0] == xs[-1] and len(set(xs)) == 1:
ref_line_meta["type"] = "vline"
elif ys[0] == ys[-1] and len(set(ys)) == 1:
ref_line_meta["type"] = "hline"
else:
ref_line_meta["type"] = "ref"

ref_line_meta["points"] = [
{"x": xs[0], "y": ys[0]},
{"x": xs[-1], "y": ys[-1]},
]
return ref_line_meta

def get_text_data(text):
"""Obtain the data needed to render text elements on the GWLandscape service
Parameters
----------
text : pyplot.Text
Text object to be rendered
Returns
-------
dict
Dictionary containing the data and metadata necessary to render the text elements
"""
return {
"label": text.get_text(),
"x": text.get_position()[0],
"y": text.get_position()[1],
"colour": get_artist_colour(text)
}

def get_line_dashes(line):
"""Obtain the dash pattern of a line. This uses a private attribute of the Line2D class or
a private function from pyplot
Parameters
----------
line : _type_pyplot.Line2D
Line object for which to obtain the dash pattern
Returns
-------
str or None
String of numbers separated by spaces representing the lengths of dashes and spaces in the pattern
"""
if hasattr(line, "_dash_pattern"):
_, dashes = line._dash_pattern
else:
_, dashes = _scale_dashes(*_get_dash_pattern(line.get_linestyle()), line.get_linewidth())

return " ".join(map(str, dashes)) if dashes else None

def get_line_meta(line, y_key, label=None):
"""Get the metadata for a line
Parameters
----------
line : pyplot.Line2D
A line from a line plot
y_key : str
The key for the data, which will be used to identify it on the frontend
label : str, optional
The name of the line, by default None. If None, the label will be obtained from the Line2D object
Returns
-------
dict
Dictionary containing the metadata necessary to render the line properly on the frontend
"""
return {
"colour": get_artist_colour(line),
"dashes": get_line_dashes(line),
"width": line.get_linewidth(),
"xKey": "x",
"yKey": y_key,
"label": line.get_label() if label is None else label,
"type": "data"
}

def get_line_groups(lines):
"""Takes a list of Line2D objects and organises them into groups based on whether or not they have the same
x data. Each group contains the x data for the group, a list of y axis data for each line in the group, and
the metadata for each line.
Parameters
----------
lines : list
List of pyplot Line2D objects
Returns
-------
list
List of groups containing the x data of the group and the y data and metadata for all lines in the group
"""
groups = []
for i, line in enumerate(lines):
x_data = line.get_xdata()
y_data = line.get_ydata()
meta = get_line_meta(line, f"y{i}")
for group in groups:
if np.array_equal(group["x_data"], x_data):
group["y_data"].append(y_data)
group["meta"].append(meta)
break
else:
groups.append({"x_data": x_data, "y_data": [y_data], "meta": [meta]})
return groups

def get_plot_data(fig):
"""Takes a pyplot Figure instance and outputs JSON data to render the plots on the GWLandscape service
Parameters
----------
fig : pyplot.Figure
Pyplot Figure to replicate
Returns
-------
dict
Dictionary containing necessary data to render the plots on a webpage
"""
json_data = {
"plots": {}
}

for ax in fig.get_axes():
if not hasattr(ax, "tag"):
continue

if ax.xaxis_inverted():
ax.invert_xaxis()

artists = ax.get_children()
ref_lines = [
get_ref_line_data(ref_line, f"refLine{i}")
for i, ref_line in enumerate(filter(_artist_is_ref_line, artists))
]
texts = [get_text_data(text) for text in filter(_artist_is_text, artists)]

line_groups = get_line_groups(filter(_artist_is_line, artists))

groups = []
for line_group in line_groups:
group = {"data": [], "meta": line_group["meta"]}
for i, x in enumerate(line_group["x_data"]):
row = {line["yKey"]: line_group["y_data"][j][i] for j, line in enumerate(group["meta"])}
row[group["meta"][0]["xKey"]] = x
group["data"].append(row)
groups.append(group)

json_data["plots"][ax.tag] = {
"meta": {
"xAxis": {
"label": ax.get_xlabel(),
"ticks": ax.get_xticks(),
"limits": ax.get_xlim(),
"scale": ax.get_xscale()
},
"yAxis": {
"label": ax.get_ylabel(),
"ticks": ax.get_yticks(),
"limits": ax.get_ylim(),
"scale": ax.get_yscale()
},
},
"groups": groups,
"refLines": ref_lines,
"texts": texts
}

return json_data

def get_events_data(events):
"""Uses a list of Events to generate a JSON structure for rendering Van Den Heuvel plots on the GWLandscape service
Parameters
----------
events : list
List of Events
Returns
-------
dict
Dictionary containing data necessary to render VDH plot on a webpage
"""
return {
"events": [
{
"eventChar": chr(ord('@') + 1 + i),
"time": event.time,
"a": [event.aprev, event.a],
"m1": [event.m1prev, event.m1],
"m2": [event.m2prev, event.m2],
"eventString": event.eventString,
"imageNum": event.image_num,
"flipImage": event.rotate_image,
} for i, event in enumerate(events)
]
}

0 comments on commit e1052db

Please sign in to comment.