Skip to content

Commit

Permalink
update plotting scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Jul 30, 2024
1 parent fb0fea0 commit 1551e7f
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 50 deletions.
10 changes: 9 additions & 1 deletion src/scripts/plots/sos/complex_squared_npcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,15 @@ def format_dataset(d: str) -> str:

setup_tueplots(num_rows, num_cols, rel_width=0.4, hw_ratio=0.8)
fig, ax = plt.subplots(num_rows, num_cols, squeeze=True, sharey=True)
g = sb.boxplot(df, x="model_id", y=metric, hue="model_id", ax=ax)
g = sb.boxplot(
df,
x="model_id",
y=metric,
hue="model_id",
width=0.7,
fliersize=3.0,
ax=ax
)
ax.set_xlabel("")
if args.ylabel:
ax.set_ylabel(format_metric(args.metric, train=args.train))
Expand Down
10 changes: 8 additions & 2 deletions src/scripts/plots/sos/complex_squared_npcs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ TBOARD_PATH="${TBOARD_PATH:-tboard-runs/complex-squared-npcs}"
for dataset in power gas hepmass miniboone
do
echo "Processing results relative to data set $dataset"
python -m "$PYSCRIPT" "$TBOARD_PATH" "$dataset" --train --ylabel &
python -m "$PYSCRIPT" "$TBOARD_PATH" "$dataset" --ylabel
if [ "$dataset" == "power" ]
then
OTHER_FLAGS="--ylabel"
else
OTHER_FLAGS=""
fi
python -m "$PYSCRIPT" "$TBOARD_PATH" "$dataset" $OTHER_FLAGS --train &
python -m "$PYSCRIPT" "$TBOARD_PATH" "$dataset" $OTHER_FLAGS
done
104 changes: 57 additions & 47 deletions src/scripts/plots/sos/num_of_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,35 @@
parser.add_argument("tboard_path", type=str, help="The Tensorboard runs path")
parser.add_argument("dataset", type=str, help="Dataset name")
parser.add_argument("--metric", default="avg_ll", help="The metric to plot")
parser.add_argument("--models", default="MonotonicPC;BornPC", help="The models")
# parser.add_argument("--xlabels", default="Number of components;Number of squares", help="The x-axis labels for each model")
parser.add_argument("--models", default="MPC;SOS", help="The models")
parser.add_argument(
"--train",
action="store_true",
default=False,
help="Whether to show the metric on the training data",
)
parser.add_argument(
"--ylabel",
action="store_true",
default=False,
help="Whether to show the y-axis label",
)


def format_metric(m: str) -> str:
def format_metric(m: str, train: bool = False) -> str:
if m == "avg_ll":
return "Average LL"
m = "Average LL"
elif m == "bpd":
return "Bits per dimension"
m = "Bits per dimension"
elif m == "ppl":
return "Perplexity"
assert False
m = "Perplexity"
else:
assert False
if train:
m = f"{m} [train]"
else:
m = f"{m} [test]"
return m


def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame:
Expand All @@ -38,11 +55,14 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame:
return df


def format_model(m: str) -> str:
if m == "MonotonicPC":
def format_model(m: str, exp_alias: str) -> str:
if m == "MPC":
return r"$+_{\mathsf{sd}}$"
elif m == "BornPC":
return r"$\Sigma_{\mathsf{sd}}^2$"
elif m == "SOS":
if exp_alias == "real":
return r"$\pm^2 (\mathbb{R})$"
elif exp_alias == "complex":
return r"$\pm^2 (\mathbb{C})$"
assert False


Expand All @@ -58,14 +78,19 @@ def format_dataset(d: str) -> str:

if __name__ == "__main__":
args = parser.parse_args()
metric = "Best/Test/" + args.metric
metric = (
("Best/Train/" + args.metric) if args.train else ("Best/Test/" + args.metric)
)
models = args.models.split(";")
df = retrieve_tboard_runs(args.tboard_path, metric)
df = retrieve_tboard_runs(os.path.join(args.tboard_path, args.dataset), metric)
df = df[df["dataset"] == args.dataset]
df = df[df["model"].isin(models)]
df = df.sort_values("model", ascending=False)
df["model"] = df["model"].apply(format_model)
df["num_replicas"] = df["num_replicas"].astype(int)
df = df[df["exp_alias"].isin(['', 'real'])]
df = df.sort_values("model", ascending=True)
df["model_id"] = df.apply(
lambda row: format_model(row.model, row.exp_alias), axis=1
)
df["num_components"] = df["num_components"].astype(int)
num_sum_parameters = df["num_sum_params"].tolist()
rel_num_sum_parameters = (
np.max(num_sum_parameters) - np.min(num_sum_parameters)
Expand All @@ -74,43 +99,28 @@ def format_dataset(d: str) -> str:
num_rows = 1
num_cols = 1

setup_tueplots(num_rows, num_cols, rel_width=0.415, hw_ratio=0.8)
setup_tueplots(num_rows, num_cols, rel_width=0.4, hw_ratio=0.8)
fig, ax = plt.subplots(num_rows, num_cols, squeeze=True, sharey=True)
g = sb.swarmplot(
df,
x="num_replicas",
y=metric,
hue="model",
ax=ax,
dodge=True,
alpha=0.55,
marker="x",
linewidth=1,
legend="brief",
)
g.get_legend().set_title(None)
sb.boxplot(
g = sb.boxplot(
df,
x="num_replicas",
x="num_components",
y=metric,
hue="model",
hue="model_id",
width=0.5,
fliersize=2.0,
ax=ax,
dodge=False,
fill=False,
gap=0.25,
whiskerprops={"visible": False},
showfliers=False,
showbox=False,
showcaps=False,
legend=False,
zorder=999,
legend=False
)
sb.move_legend(ax, handlelength=1.0, handletextpad=0.5, loc="best")
ax.set_xlabel(r"Num. of $\mathrm{MPCs}$ / $\mathrm{NPC}^2\mathrm{s}$", fontsize=12)
ax.set_ylabel(format_metric(args.metric), fontsize=12)
#sb.move_legend(ax, handlelength=1.0, handletextpad=0.5, loc="best")
ax.set_xlabel(r"Num. of components")
if args.ylabel:
ax.set_ylabel(format_metric(args.metric, train=args.train))
else:
ax.set_ylabel("")
ax.tick_params(axis="both", which="major", labelsize=10)
ax.set_title(format_dataset(args.dataset), fontsize=12)
ax.set_title(format_dataset(args.dataset))

path = os.path.join("figures", "num-of-squares")
os.makedirs(path, exist_ok=True)
plt.savefig(os.path.join(path, f"{args.dataset}.pdf"))
filename = f"{args.dataset}-train.pdf" if args.train else f"{args.dataset}-test.pdf"
plt.savefig(os.path.join(path, filename))
17 changes: 17 additions & 0 deletions src/scripts/plots/sos/num_of_squares.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
#
# Run the scripts.plots.sos.num_of_squares plotting module for every data set,
# once with --train and once without it.
#
# Environment:
#   TBOARD_PATH  Tensorboard runs path handed to the plotting module
#                (default: tboard-runs/num-of-squares-1-to-n).
#
# Exit status: 0 if every plotting process succeeded, 1 otherwise.

PYSCRIPT="scripts.plots.sos.num_of_squares"
TBOARD_PATH="${TBOARD_PATH:-tboard-runs/num-of-squares-1-to-n}"

# PIDs of the background (--train) plotting processes, reaped at the end.
pids=()
# Becomes 1 as soon as any plotting process fails.
status=0

for dataset in power gas hepmass miniboone
do
    echo "Processing results relative to data set $dataset"

    # Only "power" is drawn with a y-axis label.
    # NOTE(review): presumably because it is the left-most panel when the
    # figures are placed side by side -- confirm.
    # An array is used so the "no extra flags" case expands to nothing without
    # relying on unquoted word-splitting.
    OTHER_FLAGS=()
    if [ "$dataset" == "power" ]
    then
        OTHER_FLAGS=(--ylabel)
    fi

    # The --train plot runs in the background, concurrently with the other
    # plot of the same data set.
    python -m "$PYSCRIPT" "$TBOARD_PATH" "$dataset" "${OTHER_FLAGS[@]}" --train &
    pids+=("$!")
    python -m "$PYSCRIPT" "$TBOARD_PATH" "$dataset" "${OTHER_FLAGS[@]}" || status=1
done

# Do not return before every background plot has finished, and propagate
# their failures through the exit status.
for pid in "${pids[@]}"
do
    wait "$pid" || status=1
done

exit "$status"

0 comments on commit 1551e7f

Please sign in to comment.