Skip to content

Commit

Permalink
Update figure gen code
Browse files Browse the repository at this point in the history
  • Loading branch information
emdupre committed May 9, 2024
1 parent 3b1e569 commit 21a781f
Showing 1 changed file with 25 additions and 64 deletions.
89 changes: 25 additions & 64 deletions experiments/sigma_sweeps.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
from itertools import product

import click
Expand All @@ -8,7 +7,7 @@
import matplotlib.transforms as mtransforms
from tensorflow_probability.substrates import jax as tfp

from bayes_ca.prox_grad import pgd, pgd_jaxopt
from bayes_ca.prox_grad import pgd_jaxopt

tfd = tfp.distributions

Expand Down Expand Up @@ -37,7 +36,6 @@ def stagger_data(gap, num_timesteps, num_features):
return subj_means, x0


# @partial(jit, static_argnums=(1, 2))
def sample_mu0(params, num_timesteps, num_features, mu_pri, sigmasq_pri, hazard_rates):
"""
For provided params, generate sample data and perform PGD to sample $\mu_0$.
Expand All @@ -48,7 +46,6 @@ def sample_mu0(params, num_timesteps, num_features, mu_pri, sigmasq_pri, hazard_
return results.params


# @partial(jit, static_argnums=(2, 3))
def sample_mu0_true_x0(
params, x0, num_timesteps, num_features, mu_pri, sigmasq_pri, hazard_rates
):
Expand All @@ -72,31 +69,21 @@ def plot_mu0s(
sigma_val,
n_samples,
):
"""
Currently only supports "average" x0_strategy, rather than "true" x0.
"""
""" """
gaps = jnp.linspace(0, max_gap, n_samples)
params = jnp.asarray(list(product([sigma_val**2], gaps)))

mu0s = []
means, x0s = vmap(stagger_data, in_axes=(0, None, None))(gaps, num_timesteps, num_features)

# if x0_strategy == "true":
# # the true changepoint
# x0 = jnp.concatenate(
# (
# -1 * jnp.ones((num_timesteps // 2, num_features)),
# jnp.ones((num_timesteps // 2, num_features)),
# )
# )

for m, x0 in zip(means, x0s):
results = pgd(x0, m, mu_pri, sigma_pri**2, sigma_val**2, hazard_rates)
mu0s.append(results.x)
mu0s = jit(
vmap(sample_mu0, in_axes=(0, None, None, None, None, None)), static_argnums=(1, 2)
)(params, num_timesteps, num_features, mu_pri, sigma_pri**2, hazard_rates)

ax.set_title(f"sampled $\mu_0$ at $\sigma_{{subj}}$ = {sigma_val}")
ax.set_title(f"Sampled $\mu_0$ at $\sigma^2_{{subj}}$ = {sigma_val**2}")
colors = plt.cm.viridis(jnp.linspace(0, 1, n_samples))
for i, mu0 in enumerate(mu0s):
p = ax.plot(mu0, c=colors[i], alpha=0.8, label=f"sampled $\mu_0$, {gaps[i]} stagger")
p = ax.plot(mu0, c=colors[i], alpha=0.8, label=f"Sampled $\mu_0$, {gaps[i]} stagger")
ax.set_xlabel("Time", labelpad=10)
ax.set_yticks([-1, 0, 1])
ax.spines[["right", "top"]].set_visible(False)

return ax

Expand All @@ -113,33 +100,11 @@ def plot_param_sweep(
max_sigmasq,
n_samples,
):
"""
Currently prefers JAXOpt over COPT implementation.
"""
""" """
gaps = jnp.linspace(1, max_gap, n_samples)
sigmasqs = jnp.linspace(0.01, max_sigmasq, n_samples)
sigmas = [jnp.sqrt(s) for s in sigmasqs]

params = jnp.asarray(list(product(sigmasqs, gaps)))

# COPT, true i
# mu0s = []
# means, _ = vmap(stagger_data, in_axes=(0, None, None))(gaps, num_timesteps, num_features)

# # the true changepoint
# x0 = jnp.concatenate(
# (
# -1 * jnp.ones((num_timesteps // 2, num_features)),
# jnp.ones((num_timesteps // 2, num_features)),
# )
# )

# for sigmasq in sigmasqs:
# for mean in means:
# result = pgd(x0, mean, mu_pri, sigma_pri**2, sigmasq, hazard_rates)
# mu0s.append(result.x)

# JAXOpt
mu0s = jit(
vmap(sample_mu0, in_axes=(0, None, None, None, None, None)), static_argnums=(1, 2)
)(params, num_timesteps, num_features, mu_pri, sigma_pri**2, hazard_rates)
Expand All @@ -152,7 +117,7 @@ def count_changepoints(mu0):
count_cp = jnp.asarray([count_changepoints(mu0) for mu0 in mu0s])
sigma_by_gap = jnp.reshape(count_cp, (n_samples, n_samples))

# check, for each sigma value, whether we still have 2 changepoints when
# check, for each sigma value, whether we still have 2 states when
# increasing stagger distance...
diff_dist = sigma_by_gap[:, :-1] != sigma_by_gap[:, 1:]
diff_dist = jnp.insert(diff_dist, 0, False, axis=1)
Expand All @@ -164,16 +129,14 @@ def count_changepoints(mu0):
hazard_prob = hazard_rates[0]
beta = -jnp.log(hazard_prob / (1 - hazard_prob))

ax.plot(sigmas, gap_threshold, c="#bc3978")
ax.plot(sigmasqs, [(beta * sigmasq) for sigmasq in sigmasqs], c="#fa7f5e")
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
ax.axvline(x=vline, color="black", linestyle=(5, (10, 3)), linewidth=0.75)
ax.plot(sigmasqs, gap_threshold, c="#bc3978")
ax.plot(sigmasqs, [(beta * sigmasq) for sigmasq in sigmasqs], c="#52abef")
ax.axvline(x=(vline**2), color="black", linestyle=(5, (10, 3)), linewidth=0.75)

ax.set_xlabel("$\sigma_{{subj}}$ value", labelpad=10)
ax.set_xlabel("$\sigma^2_{{subj}}$ value", labelpad=10)
ax.set_ylabel("Stagger distance", labelpad=15)
ax.spines[["left", "top"]].set_visible(False)
ax.set_title(f"Transition from 1 to 2 changepoints")
ax.spines[["right", "top"]].set_visible(False)
ax.set_title(f"Transition to 2 changepoints")

return ax

Expand All @@ -193,9 +156,8 @@ def main(mu_pri, sigma_pri, sigma, hazard_prob, num_features, num_timesteps, x0_
hazard_rates = hazard_prob * jnp.ones(max_duration)
hazard_rates = hazard_rates.at[-1].set(1.0)

fig, axs = plt.subplot_mosaic(
[["a)"], ["b)"]], layout="constrained", sharex=True, figsize=(8, 6)
)
fig, axs = plt.subplot_mosaic([["A", "B"]], layout="constrained", figsize=(9, 4), dpi=300)

for label, ax in axs.items():
# label physical distance to the left and up:
trans = mtransforms.ScaledTranslation(-40 / 72, 7 / 72, fig.dpi_scale_trans)
Expand All @@ -211,7 +173,7 @@ def main(mu_pri, sigma_pri, sigma, hazard_prob, num_features, num_timesteps, x0_
)

panel_a = plot_param_sweep(
axs["a)"],
axs["A"],
sigma,
mu_pri,
sigma_pri,
Expand All @@ -222,9 +184,8 @@ def main(mu_pri, sigma_pri, sigma, hazard_prob, num_features, num_timesteps, x0_
max_sigmasq=9.0,
n_samples=50,
)

panel_b = plot_mu0s(
axs["b)"],
axs["B"],
mu_pri,
sigma_pri,
num_timesteps,
Expand All @@ -236,10 +197,10 @@ def main(mu_pri, sigma_pri, sigma, hazard_prob, num_features, num_timesteps, x0_
)

sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis)
cbar = fig.colorbar(sm, ax=ax[1], location="right")
cbar = fig.colorbar(sm, ax=axs["B"], location="right")
cbar.set_ticks(ticks=[0, 0.5, 1], labels=[0, 50 // 2, 50])
cbar.ax.get_yaxis().labelpad = 15
cbar.ax.set_ylabel("stagger distance", rotation=270)
cbar.ax.set_ylabel("Stagger distance", rotation=270)

plt.show()

Expand Down

0 comments on commit 21a781f

Please sign in to comment.