Skip to content

Commit

Permalink
Added tests, fixed the format of the output json, tidied the code
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom Reichardt committed Jan 30, 2025
1 parent ebb45bd commit 9d55361
Show file tree
Hide file tree
Showing 3 changed files with 537 additions and 304 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def makeDetailedPlots(Data=None, events=None, outdir='.', show=True, as_json=Fal
fig.tight_layout(h_pad=1, w_pad=1, rect=(0.08, 0.08, .98, .98), pad=0.) # (left, bottom, right, top)

if as_json:
fig_json = get_plot_data([('mass', ax1), ('length', ax2), ('hr', ax4)])
fig_json = get_plot_data([('mass_plot', ax1), ('length_plot', ax2), ('hr_plot', ax4)])
plt.close('all')
return fig_json

Expand Down
206 changes: 152 additions & 54 deletions compas_python_utils/detailed_evolution_plotter/plot_to_json.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,116 @@

import json
import re
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import _get_dash_pattern, _scale_dashes
from matplotlib.colors import to_hex, to_rgba


class NumpyEncoder(json.JSONEncoder):
""" Special json encoder for numpy types """
"""JSON Encoder to help with numpy types"""

def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
return round(float(obj), 6)
elif isinstance(obj, np.ndarray):
if np.issubdtype(obj.dtype, np.floating):
obj = np.round(obj, decimals=6)
return obj.tolist()
elif isinstance(obj, np.bool_):
return bool(obj)
return json.JSONEncoder.default(self, obj)


def _artist_is_line(artist):
return isinstance(artist, plt.Line2D) and not artist.get_label().startswith("_child")
return isinstance(artist, plt.Line2D) and not artist.get_label().startswith(
"_child"
)


def _artist_is_ref_line(artist):
return isinstance(artist, plt.Line2D) and artist.get_label().startswith("_child")


def _artist_is_text(artist):
return isinstance(artist, plt.Text) and artist.get_text()


def replace_latex_commands(latex_expr):
commands = [(r"\\odot", "⊙"), (r"\\;", " "), (r"\s+", " ")]
for pattern, replacement in commands:
latex_expr = re.sub(pattern, replacement, latex_expr)
return latex_expr


def parse_latex_math(latex_expr):
"""Use regex to parse latex maths into tokens of subscript, superscript or neither
Parameters
----------
latex_expr : str
The input latex maths expression
Returns
-------
list
List of tuples, where the first item of the tuple is one of 'superscript', 'subscript' or 'text',
and the second item of the tuple is the corresponding part of the expression
"""
pattern = (
r"(?P<subscript>_(?P<sub_single>[^{_\\\^}])|_{(?P<sub_braces>.+?)}|_(?P<sub_command>\\[a-zA-Z]+))|"
r"(?P<superscript>\^(?P<sup_single>[^{_\\\^}])|\^{(?P<sup_braces>.+?)}|\^(?P<sup_command>\\[a-zA-Z]+))|"
r"(?P<text>[^{_\^}]+)"
)

groups = []
for match in re.finditer(pattern, latex_expr):
if match.group("subscript"):
for group in ["sub_single", "sub_braces", "sub_command"]:
if match.group(group):
groups.append(
("subscript", replace_latex_commands(match.group(group)))
)
elif match.group("superscript"):
for group in ["sup_single", "sup_braces", "sup_command"]:
if match.group(group):
groups.append(
("superscript", replace_latex_commands(match.group(group)))
)
elif match.group("text"):
groups.append(("text", replace_latex_commands(match.group("text"))))

return groups


def strip_latex(input_string):
"""Takes a string, which may have latex maths expressions inside (i.e. surrounded by $),
and returns a list of tokens determining whether the text should be sub-/superscript,
and replacing some latex commands such as \\odot
Parameters
----------
input_string : str
String to be parsed
Returns
-------
list
List of tuples, where the first item of the tuple is one of 'superscript', 'subscript' or 'text',
and the second item of the tuple is the corresponding part of the string
"""
split = re.findall(r"\$([^$]+)\$|([^$]+)", input_string)

parsed = []
for part in split:
if part[0]:
parsed.extend(parse_latex_math(part[0]))
elif part[1]:
parsed.append(("text", part[1]))
return parsed


def get_artist_colour(artist):
"""Get the colour of the artist in hex format
Expand All @@ -37,7 +123,7 @@ def get_artist_colour(artist):
-------
str
String with the colour in hex format
"""
"""
return to_hex(to_rgba(artist.get_color(), artist.get_alpha()), keep_alpha=True)


Expand All @@ -54,12 +140,14 @@ def get_line_dashes(line):
-------
str or None
String of numbers separated by spaces representing the lengths of dashes and spaces in the pattern
"""
"""
if hasattr(line, "_dash_pattern"):
_, dashes = line._dash_pattern
else:
_, dashes = _scale_dashes(*_get_dash_pattern(line.get_linestyle()), line.get_linewidth())

_, dashes = _scale_dashes(
*_get_dash_pattern(line.get_linestyle()), line.get_linewidth()
)

return " ".join(map(str, dashes)) if dashes else None


Expand All @@ -79,15 +167,15 @@ def get_line_meta(line, y_key, label=None):
-------
dict
Dictionary containing the metadata necessary to render the line properly on the frontend
"""
"""
return {
"colour": get_artist_colour(line),
"dashes": get_line_dashes(line),
"width": line.get_linewidth(),
"xKey": "x",
"yKey": y_key,
"label": line.get_label() if label is None else label,
"type": "data"
"type": "data",
}


Expand All @@ -105,7 +193,7 @@ def get_ref_line_data(ref_line, label):
-------
dict
Dictionary with the data and metadata needed to render the reference line
"""
"""
ref_line_meta = get_line_meta(ref_line, "y", label)

xs, ys = ref_line.get_xdata(), ref_line.get_ydata()
Expand All @@ -117,19 +205,15 @@ def get_ref_line_data(ref_line, label):
else:
ref_line_meta["type"] = "ref"

if ref_line_meta in ["vline", "hline"]:
ref_line_data = [
{"x": xs[0], "y": ys[0]},
{"x": xs[-1], "y": ys[-1]},
]
else:
ref_line_data = [{"x": x, "y": y} for x, y in ref_line.get_xydata()]

return {
"meta": ref_line_meta,
"data": ref_line_data
"data": [
{"x": xs[0], "y": ys[0]},
{"x": xs[-1], "y": ys[-1]},
],
}


def get_text_data(text):
"""Obtain the data needed to render text elements on the GWLandscape service
Expand All @@ -142,16 +226,16 @@ def get_text_data(text):
-------
dict
Dictionary containing the data and metadata necessary to render the text elements
"""
"""
return {
"meta": {
"label": text.get_text(),
"colour": get_artist_colour(text)
"label": strip_latex(text.get_text()),
"colour": get_artist_colour(text),
},
"data": {
"x": text.get_position()[0],
"y": text.get_position()[1],
}
},
}


Expand All @@ -170,7 +254,7 @@ def get_line_groups(lines):
-------
list
List of groups containing the x data of the group and the y data and metadata for all lines in the group
"""
"""
groups = []
for i, line in enumerate(lines):
x_data = line.get_xdata()
Expand All @@ -185,6 +269,7 @@ def get_line_groups(lines):
groups.append({"x_data": x_data, "y_data": [y_data], "meta": [meta]})
return groups


def get_plot_data(axes_map):
"""Takes a pyplot Figure instance and outputs JSON data to render the plots on the GWLandscape service
Expand All @@ -198,15 +283,18 @@ def get_plot_data(axes_map):
-------
dict
Dictionary containing necessary data to render the plots on a webpage
"""
json_data = {
"plots": {}
}
"""
json_data = {"plots": []}

for label, ax in axes_map:
if ax.xaxis_inverted():
x_inverted = ax.xaxis_inverted()
if x_inverted:
ax.invert_xaxis()

y_inverted = ax.yaxis_inverted()
if y_inverted:
ax.invert_yaxis()

artists = ax.get_children()
ref_lines = [
get_ref_line_data(ref_line, f"refLine{i}")
Expand All @@ -220,33 +308,42 @@ def get_plot_data(axes_map):
for line_group in line_groups:
group = {"data": [], "meta": line_group["meta"]}
for i, x in enumerate(line_group["x_data"]):
row = {line["yKey"]: line_group["y_data"][j][i] for j, line in enumerate(group["meta"])}
row = {
line["yKey"]: line_group["y_data"][j][i]
for j, line in enumerate(group["meta"])
}
row[group["meta"][0]["xKey"]] = x
group["data"].append(row)
groups.append(group)

json_data["plots"][label] = {
"meta": {
"xAxis": {
"label": ax.get_xlabel(),
"ticks": ax.get_xticks(),
"limits": ax.get_xlim(),
"scale": ax.get_xscale()
},
"yAxis": {
"label": ax.get_ylabel(),
"ticks": ax.get_yticks(),
"limits": ax.get_ylim(),
"scale": ax.get_yscale()
json_data["plots"].append(
{
"meta": {
"label": label,
"xAxis": {
"label": strip_latex(ax.get_xlabel()),
"ticks": ax.get_xticks(),
"limits": ax.get_xlim(),
"scale": ax.get_xscale(),
"inverted": x_inverted,
},
"yAxis": {
"label": strip_latex(ax.get_ylabel()),
"ticks": ax.get_yticks(),
"limits": ax.get_ylim(),
"scale": ax.get_yscale(),
"inverted": y_inverted,
},
},
},
"groups": groups,
"refLines": ref_lines,
"texts": texts
}
"groups": groups,
"refLines": ref_lines,
"texts": texts,
}
)

return json_data


def get_events_data(events):
"""Uses a list of Events to generate a JSON structure for rendering Van Den Heuvel plots on the GWLandscape service
Expand All @@ -259,18 +356,19 @@ def get_events_data(events):
-------
dict
Dictionary containing data necessary to render VDH plot on a webpage
"""
"""
return {
"events": [
{
"eventChar": chr(ord('@') + 1 + i),
"eventChar": chr(ord("@") + 1 + i),
"time": event.time,
"a": [event.aprev, event.a],
"m1": [event.m1prev, event.m1],
"m2": [event.m2prev, event.m2],
"eventString": event.eventString,
"imageNum": event.image_num,
"flipImage": event.rotate_image,
} for i, event in enumerate(events)
}
for i, event in enumerate(events)
]
}
}
Loading

0 comments on commit 9d55361

Please sign in to comment.