Skip to content

Commit

Permalink
Refactoring slightly
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom Reichardt committed Jan 31, 2025
1 parent 9d55361 commit 5fb494f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import argparse
import tempfile
from pathlib import Path
from .plot_to_json import get_plot_data, get_events_data, NumpyEncoder
import json

IMG_DIR = Path(__file__).parent / "van_den_heuvel_figures"

Expand All @@ -26,10 +24,10 @@ def main():
parser.add_argument('--outdir', type=str, default='.', help='Path to the directory to save the figures')
parser.add_argument('--dont-show', action='store_false', help='Dont show the plots')
args = parser.parse_args()
run_main_plotter(args.data_path, outdir=args.outdir, show=args.dont_show, as_json=True)
run_main_plotter(args.data_path, outdir=args.outdir, show=args.dont_show)


def run_main_plotter(data_path, outdir='.', show=True, as_json=False):
def run_main_plotter(data_path, outdir='.', show=True):

### Collect the raw data and mask for just the end-of-timesteps events
RawData = h5.File(data_path, 'r')
Expand All @@ -45,15 +43,13 @@ def run_main_plotter(data_path, outdir='.', show=True, as_json=False):
printEvolutionaryHistory(events=events)

### Produce the two plots
detailed = makeDetailedPlots(Data, events, outdir=outdir, as_json=as_json)
vdh = plotVanDenHeuvel(events=events, outdir=outdir, as_json=as_json)

if as_json:
return json.dumps({**detailed, **vdh}, cls=NumpyEncoder)
detailed_fig = makeDetailedPlots(Data, events, outdir=outdir)
vdh_fig, vdh_events = plotVanDenHeuvel(events=events, outdir=outdir)

if show:
plt.show()

return detailed_fig, vdh_fig, vdh_events

def set_font_params():
use_latex = shutil.which("latex") is not None
Expand All @@ -75,7 +71,7 @@ def set_font_params():

####### Functions to organize and call the plotting functions

def makeDetailedPlots(Data=None, events=None, outdir='.', show=True, as_json=False):
def makeDetailedPlots(Data=None, events=None, outdir='.', show=True):
listOfPlots = [plotMassAttributes, plotLengthAttributes, plotStellarTypeAttributes, plotHertzsprungRussell]

events = [event for event in events if event.eventClass != 'Stype'] # want to ignore simple stellar type changes
Expand Down Expand Up @@ -132,12 +128,10 @@ def makeDetailedPlots(Data=None, events=None, outdir='.', show=True, as_json=Fal
fig.suptitle('Detailed evolution for seed = {}'.format(Data['SEED'][()][0]), fontsize=18)
fig.tight_layout(h_pad=1, w_pad=1, rect=(0.08, 0.08, .98, .98), pad=0.) # (left, bottom, right, top)

if as_json:
fig_json = get_plot_data([('mass_plot', ax1), ('length_plot', ax2), ('hr_plot', ax4)])
plt.close('all')
return fig_json

safe_save_figure(fig, f'{outdir}/detailedEvolutionPlot.png', bbox_inches='tight', pad_inches=0, format='png')
if outdir is not None:
safe_save_figure(fig, f'{outdir}/detailedEvolutionPlot.png', bbox_inches='tight', pad_inches=0, format='png')

return fig


######## Plotting functions
Expand Down Expand Up @@ -307,11 +301,9 @@ def get_L(t): # assumes K

return ax.get_legend_handles_labels()

def plotVanDenHeuvel(events=None, outdir='.', as_json=False):
def plotVanDenHeuvel(events=None, outdir='.'):
# Only want events with an associated image
events = [event for event in events if (event.eventImage is not None)]
if as_json:
return get_events_data(events)

num_events = len(events)
fig, axs = plt.subplots(num_events, 1)
Expand Down Expand Up @@ -344,9 +336,11 @@ def plotVanDenHeuvel(events=None, outdir='.', as_json=False):
axs[ii].annotate(chr(ord('@') + 1 + ii), xy=(-0.15, 0.8), xycoords='axes fraction', fontsize=8,
fontweight='bold')

file_path = os.path.join(outdir, 'vanDenHeuvelPlot.eps')
safe_save_figure(fig, file_path, bbox_inches='tight', pad_inches=0, format='eps')
return fig
if outdir is not None:
file_path = os.path.join(outdir, 'vanDenHeuvelPlot.eps')
safe_save_figure(fig, file_path, bbox_inches='tight', pad_inches=0, format='eps')

return fig, events


### Helper functions
Expand Down
23 changes: 23 additions & 0 deletions compas_python_utils/detailed_evolution_plotter/plot_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.pyplot as plt
from matplotlib.lines import _get_dash_pattern, _scale_dashes
from matplotlib.colors import to_hex, to_rgba
from compas_python_utils.detailed_evolution_plotter.plot_detailed_evolution import run_main_plotter


class NumpyEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -372,3 +373,25 @@ def get_events_data(events):
for i, event in enumerate(events)
]
}


def get_plot_json(data_path):
"""Get a JSON string containing the information needed to render line and VDH plots on the
GWLandscape service for a specific COMPAS output file
Parameters
----------
data_path : str or Path
Path to the COMPAS output file
Returns
-------
str
JSON string containing
"""
detailed_fig, _, events = run_main_plotter(data_path, outdir=None, show=False)
axes = detailed_fig.get_axes()
plots_data = get_plot_data([('mass_plot', axes[0]), ('length_plot', axes[1]), ('hr_plot', axes[3])])
events_data = get_events_data(events)
return json.dumps({**plots_data, **events_data}, cls=NumpyEncoder)

0 comments on commit 5fb494f

Please sign in to comment.