Skip to content

Commit

Permalink
TOC
Browse files Browse the repository at this point in the history
  • Loading branch information
Awallace3 committed Jul 9, 2024
1 parent d37a40b commit c1da1b4
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 5 deletions.
10 changes: 5 additions & 5 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def main():
subprocess.call("mv basis_study.pkl plots/basis_study.pkl", shell=True)
df = pd.read_pickle(df_name)
# print(df.columns.values)
df = src.plotting.plot_basis_sets_d4_TT(
df,
True,
)
return
df = src.plotting.plotting_setup(
(df, df_name),
False,
)
return
df = src.plotting.plot_basis_sets_d4_TT(
df,
True,
)
df = src.plotting.plot_basis_sets_d4(
df,
False,
Expand Down
Binary file added plots/basis_study_ATM_TOC_dbs_violin.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
272 changes: 272 additions & 0 deletions src/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,36 @@ def plotting_setup(
transparent=True,
# figure_size=(6, 6),
)
plot_violin_d3_d4_ALL_zoomed_min_max_TOC(
df,
{
"SAPT0/jDZ": "SAPT0_jdz_3_IE_diff",
"SAPT0/aDZ": "SAPT0_adz_3_IE_diff",
"SAPT0-D3/jDZ": "SAPT0_jdz_3_IE_d3_diff",
#SAPT "0-D3MBJ(ATM)/jDZ": "jdz_diff_d3mbj_atm",
"SAPT0-D3/aDZ": "SAPT0_adz_3_IE_d3_diff",
#SAPT "0-D3MBJ(ATM)/aDZ": "adz_diff_d3mbj_atm",
"SAPT0-D4/aDZ": "SAPT0_adz_3_IE_d4_diff",
# "0-D4(ATM)/aDZ": "SAPT0_dz_3_IE_ATM_SHARED_d4_diff",
# "0-D4(ATM)/aDZ": "SAPT0_adz_ATM_opt_all_diff",
# "0-D4(ATMu)/aDZ": "adz_diff_d4_ATM", # (2B ATM) renamed to (ATMu)
# "0-D4(2B@G ATM)/aDZ": "adz_diff_d4_2B@ATM_G",
# "0-D4(2B@G ATM@G)/aDZ": "adz_diff_d4_ATM_G",
# "0-D4(ATM TT ALL)/aDZ": "SAPT0_adz_3_IE_TT_ALL_d4_diff",
"SAPT(DFT)/aDZ": "SAPT_DFT_adz_3_IE_diff",
# "SAPT(DFT)-D4/aDZ": "SAPT_DFT_adz_3_IE_d4_diff",
"SAPT(DFT)/aTZ": "SAPT_DFT_atz_3_IE_diff",
},
"", # f"All Dimers (8299)",
# f"8299 Dimer Dataset",
f"{selected}_ATM_TOC",
bottom=0.45,
ylim=[-3, 3],
legend_loc="upper right",
transparent=True,
figure_size=(8, 2.0),

)
plot_violin_d3_d4_ALL(
df,
{
Expand Down Expand Up @@ -2095,6 +2125,248 @@ def plot_component_violin_zoomed(
plt.clf()
return

def plot_violin_d3_d4_ALL_zoomed_min_max_TOC(
df,
vals: {},
title_name: str,
pfn: str,
bottom: float = 0.4,
ylim=[-15, 35],
transparent=True,
widths=0.85,
figure_size=None,
set_xlable=False,
dpi=800,
pdf=False,
jpeg=True,
legend_loc="upper left",
) -> None:
print(f"Plotting {pfn}")
image_ext = "png"
if jpeg:
image_ext = "jpeg"

vLabels, vData = [], []
annotations = [] # [(x, y, text), ...]
cnt = 1
for k, v in vals.items():
df[v] = pd.to_numeric(df[v])
df_sub = df[df[v].notna()].copy()
vData.append(df_sub[v].to_list())
k_label = "\\textbf{" + k + "}"
# k_label = k
vLabels.append(k_label)
m = df_sub[v].max()
rmse = df_sub[v].apply(lambda x: x**2).mean() ** 0.5
mae = df_sub[v].apply(lambda x: abs(x)).mean()
max_pos_error = df_sub[v].apply(lambda x: x).max()
max_neg_error = df_sub[v].apply(lambda x: x).min()
text = r"\textit{%.2f}" % mae
text += "\n"
text += r"\textbf{%.2f}" % rmse
text += "\n"
text += r"\textrm{%.2f}" % max_pos_error
text += "\n"
text += r"\textrm{%.2f}" % max_neg_error
annotations.append((cnt, m, text))
cnt += 1

pd.set_option("display.max_columns", None)
# print(df[vals.values()].describe(include="all"))
# transparent figure
fig = plt.figure(dpi=dpi)
if figure_size is not None:
plt.figure(figsize=figure_size)
gs = gridspec.GridSpec(
1, 1, height_ratios=[1]
) # Adjust height ratios to change the size of subplots

# Create the main violin plot axis
ax = plt.subplot(gs[0]) # This will create the subplot for the main violin plot.
# ax = plt.subplot(111)
vplot = ax.violinplot(
vData,
showmeans=True,
showmedians=False,
showextrema=False,
# quantiles=[[0.05, 0.95] for i in range(len(vData))],
widths=widths,
)
# for partname in ('cbars', 'cmins', 'cmaxes', 'cmeans', 'cmedians'):
for n, partname in enumerate(["cmeans"]):
vp = vplot[partname]
vp.set_edgecolor("black")
vp.set_linewidth(1)
vp.set_alpha(1)
quantile_color = "red"
quantile_style = "-"
quantile_linewidth = 0.8
# for n, partname in enumerate(["cquantiles"]):
# vp = vplot[partname]
# vp.set_edgecolor(quantile_color)
# vp.set_linewidth(quantile_linewidth)
# vp.set_linestyle(quantile_style)
# vp.set_alpha(1)

colors = ["blue" if i % 2 == 0 else "green" for i in range(len(vLabels))]
# color_gt_olympic_teal = (0 /255, 140/255, 149/255) # Olympic teal
# color_gt_bold_blue = (58/255, 93/255, 174/255)
# colors = [color_gt_bold_blue if i % 2 == 0 else color_gt_olympic_teal for i in range(len(vLabels))]
for n, pc in enumerate(vplot["bodies"], 1):
pc.set_facecolor(colors[n - 1])
pc.set_alpha(0.6)
# pc.set_alpha(1)

vLabels.insert(0, "")
xs = [i for i in range(len(vLabels))]
xs_error = [i for i in range(-1, len(vLabels) + 1)]
ax.plot(
xs_error,
[1 for i in range(len(xs_error))],
"k--",
label=r"$\pm$1 $\mathrm{kcal\cdot mol^{-1}}$",
zorder=0,
linewidth=0.6,
)
ax.plot(
xs_error,
[0 for i in range(len(xs_error))],
"k--",
linewidth=0.5,
alpha=0.5,
# label=r"Reference Energy",
zorder=0,
)
ax.plot(
xs_error,
[-1 for i in range(len(xs_error))],
"k--",
zorder=0,
linewidth=0.6,
)
ax.plot(
[],
[],
linestyle=quantile_style,
color=quantile_color,
linewidth=quantile_linewidth,
label=r"5-95th Percentile",
)
# TODO: fix minor ticks to be between
ax.set_xticks(xs)
# minor_yticks = np.arange(ylim[0], ylim[1], 2)
# ax.set_yticks(minor_yticks, minor=True)

plt.setp(ax.set_xticklabels(vLabels), rotation=-45, fontsize="16", ha="left")
ax.set_xlim((0, len(vLabels)))
ax.set_ylim(ylim)

minor_yticks = create_minor_y_ticks(ylim)
ax.set_yticks(minor_yticks, minor=True)

# lg = ax.legend(loc=legend_loc, edgecolor="black", fontsize="9")
# lg.get_frame().set_alpha(None)
# lg.get_frame().set_facecolor((1, 1, 1, 0.0))

if set_xlable:
ax.set_xlabel("Level of Theory", color="k", fontsize="12")
# ax.set_ylabel(r"Error ($\mathrm{kcal\cdot mol^{-1}}$)", color="k", fontsize="14")
ax.set_ylabel(r"Error (kcal$\cdot$mol$^{-1}$)", color="k", fontsize="14")

ax.grid(color="#54585A", which="major", linewidth=0.5, alpha=0.5, axis="y")
ax.grid(color="#54585A", which="minor", linewidth=0.5, alpha=0.5)
# Annotations of RMSE

plt.setp(ax.set_xticklabels(vLabels), rotation=-45, fontsize="16", ha="left")
ax.set_xlim((0, len(vLabels)))
ax.set_ylim(ylim)

minor_yticks = create_minor_y_ticks(ylim)
ax.set_yticks(minor_yticks, minor=True)
# incread tick sizes
ax.tick_params(axis="both", which="major", labelsize=16)
# lg = ax.legend(loc=legend_loc, edgecolor="black", fontsize="9")
if set_xlable:
ax.set_xlabel("Level of Theory", color="k", fontsize="12")
ax.set_ylabel(r"Error (kcal$\cdot$mol$^{-1}$)", color="k", fontsize="14")
ax.grid(color="#54585A", which="major", linewidth=0.5, alpha=0.5, axis="y")
ax.grid(color="#54585A", which="minor", linewidth=0.5, alpha=0.5)

for n, xtick in enumerate(ax.get_xticklabels()):
xtick.set_color(colors[n - 1])
xtick.set_alpha(0.8)

# ax_error = plt.subplot(gs[0], sharex=ax)
# # ax_error.spines['top'].set_visible(False)
# ax_error.spines["right"].set_visible(False)
# ax_error.spines["left"].set_visible(False)
# ax_error.spines["bottom"].set_visible(False)
# ax_error.tick_params(left=False, labelleft=False, bottom=False, labelbottom=False)
#
# # Synchronize the x-limits with the main subplot
# ax_error.set_xlim((0, len(vLabels)))
# ax_error.set_ylim(0, 1) # Assuming the upper subplot should have no y range
# # Populate ax_error with error statistics through annotations
# # text = r"\textit{%.2f}" % mae
# # text += r"\textbf{%.2f}" % rmse
# # text += r"\textrm{%.2f}" % max_error
# error_labels = r"\textit{MAE}"
# error_labels += "\n"
# error_labels += r"\textbf{RMSE}"
# error_labels += "\n"
# error_labels += r"\textrm{MaxE}"
# error_labels += "\n"
# error_labels += r"\textrm{MinE}"
# ax_error.annotate(
# error_labels,
# xy=(0, 1), # Position at the vertical center of the narrow subplot
# xytext=(0, 0.2),
# color="black",
# fontsize="8",
# ha="center",
# va="center",
# )
# for idx, (x, y, text) in enumerate(annotations):
# ax_error.annotate(
# text,
# xy=(x, 1), # Position at the vertical center of the narrow subplot
# # xytext=(0, 0),
# xytext=(x, 0.2),
# color="black",
# fontsize="9",
# ha="center",
# va="center",
# )

if title_name is not None:
plt.title(f"{title_name}")
fig.subplots_adjust(bottom=bottom)

if pdf:
fn_pdf = f"plots/{pfn}_dbs_violin.pdf"
fn_png = f"plots/{pfn}_dbs_violin.png"
plt.savefig(
fn_pdf,
transparent=transparent,
bbox_inches="tight",
dpi=dpi,
)
if os.path.exists(fn_png):
os.system(f"rm {fn_png}")
os.system(f"pdftoppm -png -r 400 {fn_pdf} {fn_png}")
if os.path.exists(f"{fn_png}-1.png"):
os.system(f"mv {fn_png}-1.png {fn_png}")
else:
print(f"Error: {fn_png}-1.png does not exist")
else:
plt.savefig(
f"plots/{pfn}_dbs_violin.{image_ext}",
transparent=transparent,
bbox_inches="tight",
dpi=dpi,
)
plt.clf()
return

def plot_violin_d3_d4_ALL_zoomed_min_max(
df,
Expand Down

0 comments on commit c1da1b4

Please sign in to comment.