Skip to content

Commit

Permalink
Plot FM and ablation responses
Browse files Browse the repository at this point in the history
  • Loading branch information
phuongho43 committed Dec 10, 2024
1 parent 14dd334 commit 13b4c2d
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 35 deletions.
211 changes: 176 additions & 35 deletions protosignet/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,9 @@
import seaborn as sns

from protosignet.model import sim_signet
from protosignet.style import RC_PARAMS
from protosignet.util import calc_n_nodes, eval_pareto, tag_objectives

CUSTOM_PALETTE = ["#648FFF", "#2ECC71", "#8069EC", "#EA822C", "#D143A4", "#F1C40F", "#34495E"]

CUSTOM_STYLE = {
"image.cmap": "turbo",
"figure.figsize": (24, 16),
"text.color": "#212121",
"axes.spines.top": False,
"axes.spines.right": False,
"axes.labelpad": 12,
"axes.labelcolor": "#212121",
"axes.labelweight": 600,
"axes.linewidth": 6,
"axes.edgecolor": "#212121",
"grid.linewidth": 1,
"xtick.major.pad": 12,
"ytick.major.pad": 12,
"lines.linewidth": 10,
"axes.labelsize": 72,
"xtick.labelsize": 56,
"ytick.labelsize": 56,
"legend.fontsize": 48,
}


def plot_figure_1d(data_dp, fig_fp):
"""Generate scatterplot of obj 1 (simplicity) vs obj 2 (performance) over all runs/repeats.
Expand All @@ -49,7 +27,7 @@ def plot_figure_1d(data_dp, fig_fp):
df_top["is_pareto"] = eval_pareto(df_top[["obj1", "obj2"]].to_numpy())
df_pareto = df_top.loc[df_top["is_pareto"] == 1]
print(df_pareto)
with plt.style.context(("seaborn-v0_8-whitegrid", CUSTOM_STYLE)):
with sns.axes_style("whitegrid"), mpl.rc_context(RC_PARAMS):
fig, ax = plt.subplots(figsize=(24, 20))
sns.scatterplot(data=df_gen_001, x="obj1", y="obj2", edgecolor="#212121", facecolor="#2ECC71", alpha=0.8, linewidth=2, s=600)
sns.scatterplot(data=df_gen_010, x="obj1", y="obj2", edgecolor="#212121", facecolor="#F1C40F", alpha=0.8, linewidth=2, s=600)
Expand Down Expand Up @@ -113,17 +91,17 @@ def plot_figure_1e(address, data_dp, fig_fp):
X1_df = pd.DataFrame({"t": tm, "y": Xm[0], "h": np.ones_like(tm) * 0})
X2_df = pd.DataFrame({"t": tm, "y": Xm[1], "h": np.ones_like(tm) * 1})
Xm_df = pd.concat([X1_df, X2_df], ignore_index=True)
with plt.style.context(("seaborn-v0_8-whitegrid", CUSTOM_STYLE)):
with sns.axes_style("whitegrid"), mpl.rc_context(RC_PARAMS):
fig, ax = plt.subplots(figsize=(24, 20))
sns.lineplot(data=Xm_df, x="t", y="y", hue="h", ax=ax, palette=["#8069EC", "#EA822C"], zorder=2.2)
ymin, ymax = ax.get_ylim()
sns.lineplot(data=Xm_df, x="t", y="y", hue="h", ax=ax, palette=["#8069EC", "#EA822C"], alpha=0.8, zorder=2.2)
for t in tu[uu > 0]:
ax.axvspan(t, t + 1, color="#648FFF", alpha=0.5, linewidth=0, zorder=2.1)
ax.set_ylim(ymin, ymax)
ax.axvspan(t - 1, t, color="#648FFF", alpha=0.5, linewidth=0, zorder=2.1)
ax.yaxis.set_ticks(np.arange(0, 1.1, 0.2))
ax.set_ylim(-0.1, 1.1)
handles = [
mpl.lines.Line2D([], [], color="#648FFF", linewidth=16, alpha=0.5),
mpl.lines.Line2D([], [], color="#8069EC", linewidth=16),
mpl.lines.Line2D([], [], color="#EA822C", linewidth=16),
mpl.lines.Line2D([], [], color="#8069EC", linewidth=16, alpha=1.0),
mpl.lines.Line2D([], [], color="#EA822C", linewidth=16, alpha=1.0),
]
group_labels = ["Input", "Dense Decoder", "Sparse Decoder"]
ax.legend(
Expand All @@ -135,12 +113,139 @@ def plot_figure_1e(address, data_dp, fig_fp):
shadow=False,
framealpha=1.0,
handletextpad=0.4,
borderpad=0.2,
borderpad=0.4,
labelspacing=0.2,
handlelength=1,
)
ax.set_xlabel("Time")
ax.set_ylabel("Output")
ax.locator_params(axis="x", nbins=10)
ax.locator_params(axis="y", nbins=10)
fig.tight_layout()
fig.canvas.draw()
fig.savefig(fig_fp, pad_inches=0.3, dpi=200, bbox_inches="tight", transparent=False)
plt.close("all")


def plot_figure_1f(address, data_dp, fig_fp):
"""Simulate FM response for a specified motif.
Args:
address (list): [rep_i, gen_j, pop_k]
data_dp (str): absolute path to data directory
fig_fp (str): absolute path for saving generated figure
"""
rep_i, gen_j, pop_k = address
df_rep = pd.read_csv(data_dp / f"{int(rep_i)}.csv")
pop_rep = df_rep["population"].values
pop_gen = np.array(ast.literal_eval(pop_rep[int(gen_j)]))
n_params = pop_gen.shape[1]
n_nodes = calc_n_nodes(n_params)
indiv = pop_gen[int(pop_k)].reshape(int(n_nodes), -1)
kr = indiv[:, 0]
ku = indiv[:, 1]
kX = indiv[:, 2:]
tu = np.arange(0, 121, 1)
periods = [1, 2, 3, 4, 5, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32, 36, 40, 50, 60, 70, 80, 90, 100, 121]
periods = np.array(periods)
freqs = 1 / periods
ave_denser = []
ave_sparser = []
for period in periods:
uu = np.zeros_like(tu)
uu[period:121:period] = 1
tm, Xm = sim_signet(tu, uu, kr, ku, kX)
ave_denser.append(np.mean(Xm[0]))
ave_sparser.append(np.mean(Xm[1]))
ave_denser = np.array(ave_denser)
ave_sparser = np.array(ave_sparser)
print(f"Dense Decoder FM Peak: {freqs[np.argmax(ave_denser)]} Hz")
print(f"Sparse Decoder FM Peak: {freqs[np.argmax(ave_sparser)]} Hz")
ave_denser_df = pd.DataFrame({"t": freqs, "y": ave_denser, "h": np.ones_like(freqs) * 0})
ave_sparser_df = pd.DataFrame({"t": freqs, "y": ave_sparser, "h": np.ones_like(freqs) * 1})
ave_df = pd.concat([ave_denser_df, ave_sparser_df], ignore_index=True)
with sns.axes_style("whitegrid"), mpl.rc_context(RC_PARAMS):
fig, ax = plt.subplots(figsize=(24, 20))
palette = ["#8069EC", "#EA822C"]
sns.lineplot(data=ave_df, x="t", y="y", hue="h", ax=ax, palette=palette)
ax.set_xlabel("FM Input (Hz)")
ax.set_ylabel("Mean Ouput (AU)")
ax.set_xscale("log")
group_labels = ["Dense Decoder", "Sparse Decoder"]
ax.yaxis.set_ticks(np.arange(0, 1.1, 0.2))
ax.set_ylim(-0.1, 1.1)
handles = [
mpl.lines.Line2D([], [], color="#8069EC", linewidth=16, alpha=1.0),
mpl.lines.Line2D([], [], color="#EA822C", linewidth=16, alpha=1.0),
]
ax.legend(
handles,
group_labels,
loc="best",
markerscale=4,
frameon=True,
shadow=False,
framealpha=1.0,
handletextpad=0.4,
borderpad=0.4,
labelspacing=0.2,
handlelength=1,
)
fig.tight_layout()
fig.canvas.draw()
fig.savefig(fig_fp, pad_inches=0.3, dpi=200, bbox_inches="tight", transparent=False)
plt.close("all")


def plot_figure_1g(k_params, fig_fp):
"""Study FM decoder motif with modified reactions.
Args:
k_params (1 x N+N+N*N array): parameters for simulating the signet model with N nodes
fig_fp (str): absolute path for saving generated figure
"""
k_params = np.array(k_params)
n_nodes = calc_n_nodes(len(k_params))
k_params = k_params.reshape(int(n_nodes), -1)
kr = k_params[:, 0]
ku = k_params[:, 1]
kX = k_params[:, 2:]
tu = np.arange(0, 121, 1.0)
uu = np.zeros_like(tu)
uu[40:80:10] = 1.0
uu[80:121:1] = 1.0
tm, Xm = sim_signet(tu, uu, kr, ku, kX)
X1_df = pd.DataFrame({"t": tm, "y": Xm[0], "h": np.ones_like(tm) * 0})
X2_df = pd.DataFrame({"t": tm, "y": Xm[1], "h": np.ones_like(tm) * 1})
Xm_df = pd.concat([X1_df, X2_df], ignore_index=True)
with sns.axes_style("whitegrid"), mpl.rc_context(RC_PARAMS):
fig, ax = plt.subplots(figsize=(24, 20))
sns.lineplot(data=Xm_df, x="t", y="y", hue="h", ax=ax, palette=["#8069EC", "#EA822C"], alpha=0.8, zorder=2.2)
for t in tu[uu > 0]:
ax.axvspan(t - 1, t, color="#648FFF", alpha=0.5, linewidth=0, zorder=2.1)
ax.yaxis.set_ticks(np.arange(0, 1.1, 0.2))
ax.set_ylim(-0.1, 1.1)
handles = [
mpl.lines.Line2D([], [], color="#648FFF", linewidth=16, alpha=0.5),
mpl.lines.Line2D([], [], color="#8069EC", linewidth=16, alpha=1.0),
mpl.lines.Line2D([], [], color="#EA822C", linewidth=16, alpha=1.0),
]
group_labels = ["Input", "Dense Decoder", "Sparse Decoder"]
ax.legend(
handles,
group_labels,
loc="best",
markerscale=4,
frameon=True,
shadow=False,
framealpha=1.0,
handletextpad=0.4,
borderpad=0.4,
labelspacing=0.2,
handlelength=1,
)
ax.set_xlabel("Time")
ax.set_ylabel("AU")
ax.set_ylabel("Output")
ax.locator_params(axis="x", nbins=10)
ax.locator_params(axis="y", nbins=10)
fig.tight_layout()
Expand All @@ -154,13 +259,49 @@ def main():
save_dp = Path("/home/phuong/data/protosignet/dual_fm/figs/")
save_dp.mkdir(parents=True, exist_ok=True)

# fig_fp = save_dp / "fig_1d.png"
# plot_figure_1d(data_dp, fig_fp)
fig_fp = save_dp / "fig_1d.png"
plot_figure_1d(data_dp, fig_fp)

for a, address in enumerate([[0, 241, 93], [1, 235, 83], [3, 246, 45]]):
fig_fp = save_dp / f"fig_1e_{a}.png"
plot_figure_1e(address, data_dp, fig_fp)

fig_fp = save_dp / "fig_1f.png"
address = [3, 246, 45]
plot_figure_1f(address, data_dp, fig_fp)

k_params_0 = [ # regular FM decoder motif
-1, 10, 0, 0, 0, 0, 0,
-0.1, 0.01, -10, 10, 0, 0, 0,
-10, 0, 0, 0, 0, 0, 0,
-1, 0, 0, 0, 0, 0, 0,
-10, 0, 0, 0, 0, 0, 0,
] # fmt: skip
k_params_1 = [ # weaker [X2] self-activation
-1, 10, 0, 0, 0, 0, 0,
-0.1, 0.01, -10, 1, 0, 0, 0,
-10, 0, 0, 0, 0, 0, 0,
-1, 0, 0, 0, 0, 0, 0,
-10, 0, 0, 0, 0, 0, 0,
] # fmt: skip
k_params_2 = [ # weaker [X1] repression towards [X2]
-1, 10, 0, 0, 0, 0, 0,
-0.1, 0.01, -1, 10, 0, 0, 0,
-10, 0, 0, 0, 0, 0, 0,
-1, 0, 0, 0, 0, 0, 0,
-10, 0, 0, 0, 0, 0, 0,
] # fmt: skip
k_params_3 = [ # stronger [X2] induction
-1, 10, 0, 0, 0, 0, 0,
-0.1, 10, -10, 10, 0, 0, 0,
-10, 0, 0, 0, 0, 0, 0,
-1, 0, 0, 0, 0, 0, 0,
-10, 0, 0, 0, 0, 0, 0,
] # fmt: skip
for i, kk in enumerate([k_params_0, k_params_1, k_params_2, k_params_3]):
fig_fp = save_dp / f"fig_1g_{i}.png"
plot_figure_1g(kk, fig_fp)


if __name__ == "__main__":
main()
33 changes: 33 additions & 0 deletions protosignet/style.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
PALETTE = ["#648FFF", "#2ECC71", "#8069EC", "#EA822C", "#D143A4", "#F1C40F", "#34495E"]

RC_PARAMS = {
"figure.figsize": (24, 16),
"lines.linewidth": 8,
"text.color": "#212121",
"axes.spines.top": False,
"axes.spines.right": False,
"axes.labelsize": 72,
"axes.labelpad": 12,
"axes.labelcolor": "#212121",
"axes.labelweight": 600,
"axes.linewidth": 6,
"axes.edgecolor": "#212121",
"xtick.bottom": True,
"ytick.left": True,
"xtick.major.pad": 12,
"ytick.major.pad": 12,
"xtick.labelsize": 56,
"ytick.labelsize": 56,
"xtick.color": "#212121",
"ytick.color": "#212121",
"xtick.major.size": 24,
"ytick.major.size": 24,
"xtick.major.width": 6,
"ytick.major.width": 6,
"xtick.minor.size": 12,
"ytick.minor.size": 12,
"xtick.minor.width": 2,
"ytick.minor.width": 2,
"legend.fontsize": 48,
"grid.linewidth": 1,
}
2 changes: 2 additions & 0 deletions protosignet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,6 @@ def calc_n_nodes(n_params):
while y != 0:
x += 1
y = x**2 + 2 * x - n_params
if x > 10:
raise ValueError("n_nodes")
return x

0 comments on commit 13b4c2d

Please sign in to comment.