diff --git a/compas_python_utils/detailed_evolution_plotter/plot_detailed_evolution.py b/compas_python_utils/detailed_evolution_plotter/plot_detailed_evolution.py index 88deb6a1..9ce67eec 100644 --- a/compas_python_utils/detailed_evolution_plotter/plot_detailed_evolution.py +++ b/compas_python_utils/detailed_evolution_plotter/plot_detailed_evolution.py @@ -12,8 +12,6 @@ import argparse import tempfile from pathlib import Path -from .plot_to_json import get_plot_data, get_events_data, NumpyEncoder -import json IMG_DIR = Path(__file__).parent / "van_den_heuvel_figures" @@ -26,10 +24,10 @@ def main(): parser.add_argument('--outdir', type=str, default='.', help='Path to the directory to save the figures') parser.add_argument('--dont-show', action='store_false', help='Dont show the plots') args = parser.parse_args() - run_main_plotter(args.data_path, outdir=args.outdir, show=args.dont_show, as_json=True) + run_main_plotter(args.data_path, outdir=args.outdir, show=args.dont_show) -def run_main_plotter(data_path, outdir='.', show=True, as_json=False): +def run_main_plotter(data_path, outdir='.', show=True): ### Collect the raw data and mask for just the end-of-timesteps events RawData = h5.File(data_path, 'r') @@ -45,15 +43,13 @@ def run_main_plotter(data_path, outdir='.', show=True, as_json=False): printEvolutionaryHistory(events=events) ### Produce the two plots - detailed = makeDetailedPlots(Data, events, outdir=outdir, as_json=as_json) - vdh = plotVanDenHeuvel(events=events, outdir=outdir, as_json=as_json) - - if as_json: - return json.dumps({**detailed, **vdh}, cls=NumpyEncoder) + detailed_fig = makeDetailedPlots(Data, events, outdir=outdir) + vdh_fig, vdh_events = plotVanDenHeuvel(events=events, outdir=outdir) if show: plt.show() + return detailed_fig, vdh_fig, vdh_events def set_font_params(): use_latex = shutil.which("latex") is not None @@ -75,7 +71,7 @@ def set_font_params(): ####### Functions to organize and call the plotting functions -def makeDetailedPlots(Data=None, events=None, outdir='.', show=True, as_json=False): +def makeDetailedPlots(Data=None, events=None, outdir='.', show=True): listOfPlots = [plotMassAttributes, plotLengthAttributes, plotStellarTypeAttributes, plotHertzsprungRussell] events = [event for event in events if event.eventClass != 'Stype'] # want to ignore simple stellar type changes @@ -132,12 +128,10 @@ def makeDetailedPlots(Data=None, events=None, outdir='.', show=True, as_json=Fal fig.suptitle('Detailed evolution for seed = {}'.format(Data['SEED'][()][0]), fontsize=18) fig.tight_layout(h_pad=1, w_pad=1, rect=(0.08, 0.08, .98, .98), pad=0.) # (left, bottom, right, top) - if as_json: - fig_json = get_plot_data([('mass_plot', ax1), ('length_plot', ax2), ('hr_plot', ax4)]) - plt.close('all') - return fig_json - - safe_save_figure(fig, f'{outdir}/detailedEvolutionPlot.png', bbox_inches='tight', pad_inches=0, format='png') + if outdir is not None: + safe_save_figure(fig, f'{outdir}/detailedEvolutionPlot.png', bbox_inches='tight', pad_inches=0, format='png') + + return fig ######## Plotting functions @@ -307,11 +301,9 @@ def get_L(t): # assumes K return ax.get_legend_handles_labels() -def plotVanDenHeuvel(events=None, outdir='.', as_json=False): +def plotVanDenHeuvel(events=None, outdir='.'): # Only want events with an associated image events = [event for event in events if (event.eventImage is not None)] - if as_json: - return get_events_data(events) num_events = len(events) fig, axs = plt.subplots(num_events, 1) @@ -344,9 +336,11 @@ def plotVanDenHeuvel(events=None, outdir='.', as_json=False): axs[ii].annotate(chr(ord('@') + 1 + ii), xy=(-0.15, 0.8), xycoords='axes fraction', fontsize=8, fontweight='bold') - file_path = os.path.join(outdir, 'vanDenHeuvelPlot.eps') - safe_save_figure(fig, file_path, bbox_inches='tight', pad_inches=0, format='eps') - return fig + if outdir is not None: + file_path = os.path.join(outdir, 'vanDenHeuvelPlot.eps') + safe_save_figure(fig, file_path, bbox_inches='tight', pad_inches=0, format='eps') + + return fig, events ### Helper functions diff --git a/compas_python_utils/detailed_evolution_plotter/plot_to_json.py b/compas_python_utils/detailed_evolution_plotter/plot_to_json.py index c78eeeab..283130d6 100644 --- a/compas_python_utils/detailed_evolution_plotter/plot_to_json.py +++ b/compas_python_utils/detailed_evolution_plotter/plot_to_json.py @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt from matplotlib.lines import _get_dash_pattern, _scale_dashes from matplotlib.colors import to_hex, to_rgba +from compas_python_utils.detailed_evolution_plotter.plot_detailed_evolution import run_main_plotter class NumpyEncoder(json.JSONEncoder): @@ -372,3 +373,25 @@ def get_events_data(events): for i, event in enumerate(events) ] } + + +def get_plot_json(data_path): + """Get a JSON string containing the information needed to render line and VDH plots on the + GWLandscape service for a specific COMPAS output file + + Parameters + ---------- + data_path : str or Path + Path to the COMPAS output file + + Returns + ------- + str + JSON string containing + """ + detailed_fig, _, events = run_main_plotter(data_path, outdir=None, show=False) + axes = detailed_fig.get_axes() + plots_data = get_plot_data([('mass_plot', axes[0]), ('length_plot', axes[1]), ('hr_plot', axes[3])]) + events_data = get_events_data(events) + return json.dumps({**plots_data, **events_data}, cls=NumpyEncoder) + \ No newline at end of file