Skip to content

Commit

Permalink
Merge pull request #68 from thomaswmorris/acquisition
Browse files Browse the repository at this point in the history
Better acquisition functions
  • Loading branch information
thomaswmorris authored May 10, 2024
2 parents be5b001 + 0dcbfb4 commit a4fa614
Show file tree
Hide file tree
Showing 36 changed files with 708 additions and 946 deletions.
14 changes: 6 additions & 8 deletions docs/source/tutorials/hyperparameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@
"metadata": {},
"outputs": [],
"source": [
"def digestion(db, uid):\n",
" products = db[uid].table()\n",
"def digestion(df):\n",
" for index, entry in df.iterrows():\n",
" df.loc[index, \"booth\"] = functions.booth(entry.x1, entry.x2)\n",
"\n",
" for index, entry in products.iterrows():\n",
" products.loc[index, \"booth\"] = functions.booth(entry.x1, entry.x2)\n",
"\n",
" return products"
" return df"
]
},
{
Expand Down Expand Up @@ -91,7 +89,7 @@
" db=db,\n",
")\n",
"\n",
"RE(agent.learn(acq_func=\"qr\", n=16))\n",
"RE(agent.learn(acqf=\"qr\", n=16))\n",
"\n",
"agent.plot_objectives()"
]
Expand All @@ -114,7 +112,7 @@
},
"outputs": [],
"source": [
"agent.plot_acquisition(acq_func=\"qei\")"
"agent.plot_acquisition(acqf=\"qei\")"
]
},
{
Expand Down
23 changes: 9 additions & 14 deletions docs/source/tutorials/introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,11 @@
"metadata": {},
"outputs": [],
"source": [
"def digestion(db, uid):\n",
" products = db[uid].table()\n",
"def digestion(df):\n",
" for index, entry in df.iterrows():\n",
" df.loc[index, \"himmelblau\"] = functions.himmelblau(entry.x1, entry.x2)\n",
"\n",
" for index, entry in products.iterrows():\n",
" products.loc[index, \"himmelblau\"] = functions.himmelblau(entry.x1, entry.x2)\n",
"\n",
" return products"
" return df"
]
},
{
Expand Down Expand Up @@ -184,7 +182,7 @@
"metadata": {},
"outputs": [],
"source": [
"agent.all_acq_funcs"
"agent.all_acqfs"
]
},
{
Expand All @@ -203,7 +201,7 @@
"metadata": {},
"outputs": [],
"source": [
"agent.plot_acquisition(acq_func=\"qei\")"
"agent.plot_acquisition(acqf=\"qei\")"
]
},
{
Expand Down Expand Up @@ -246,12 +244,9 @@
"outputs": [],
"source": [
"res = agent.ask(\"qei\", n=8, route=True)\n",
"agent.plot_acquisition(acq_func=\"qei\")\n",
"plt.scatter(*res[\"points\"].T, marker=\"d\", facecolor=\"w\", edgecolor=\"k\")\n",
"plt.plot(\n",
" *res[\"points\"].T,\n",
" color=\"r\",\n",
")"
"agent.plot_acquisition(acqf=\"qei\")\n",
"plt.scatter(res[\"points\"][\"x1\"], res[\"points\"][\"x2\"], marker=\"d\", facecolor=\"w\", edgecolor=\"k\")\n",
"plt.plot(res[\"points\"][\"x1\"], res[\"points\"][\"x2\"], color=\"r\")"
]
},
{
Expand Down
16 changes: 7 additions & 9 deletions docs/source/tutorials/pareto-fronts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,16 @@
"import numpy as np\n",
"\n",
"\n",
"def digestion(db, uid):\n",
" products = db[uid].table()\n",
"\n",
" for index, entry in products.iterrows():\n",
"def digestion(df):\n",
" for index, entry in df.iterrows():\n",
" x1, x2 = entry.x1, entry.x2\n",
"\n",
" products.loc[index, \"f1\"] = (x1 - 2) ** 2 + (x2 - 1) + 2\n",
" products.loc[index, \"f2\"] = 9 * x1 - (x2 - 1) + 2\n",
" products.loc[index, \"c1\"] = x1**2 + x2**2\n",
" products.loc[index, \"c2\"] = x1 - 3 * x2 + 10\n",
" df.loc[index, \"f1\"] = (x1 - 2) ** 2 + (x2 - 1) + 2\n",
" df.loc[index, \"f2\"] = 9 * x1 - (x2 - 1) + 2\n",
" df.loc[index, \"c1\"] = x1**2 + x2**2\n",
" df.loc[index, \"c2\"] = x1 - 3 * x2 + 10\n",
"\n",
" return products\n",
" return df\n",
"\n",
"\n",
"dofs = [\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/passive-dofs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"metadata": {},
"outputs": [],
"source": [
"from blop.utils import functions\n",
"from blop.digestion.tests import constrained_himmelblau_digestion\n",
"from blop import DOF, Agent, Objective\n",
"from blop.dofs import BrownianMotion\n",
"\n",
Expand All @@ -58,7 +58,7 @@
"agent = Agent(\n",
" dofs=dofs,\n",
" objectives=objectives,\n",
" digestion=functions.constrained_himmelblau_digestion,\n",
" digestion=constrained_himmelblau_digestion,\n",
" db=db,\n",
" verbose=True,\n",
" tolerate_acquisition_errors=False,\n",
Expand Down
18 changes: 10 additions & 8 deletions scripts/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@
err_cbar = obj_plt.fig.colorbar(mappable=im3, ax=[ax3], location="bottom", aspect=16)

for ax in [ax1, ax2, ax3]:
ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)
ax.set_xlabel(agent.dofs[0].label_with_units)
ax.set_ylabel(agent.dofs[1].label_with_units)


acqf_configs = {
Expand Down Expand Up @@ -103,8 +103,8 @@
acqf_plt_objs[acqf]["hist"] = ax.scatter([], [])
acqf_plt_objs[acqf]["best"] = ax.scatter([], [])

ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)
ax.set_xlabel(agent.dofs[0].label_with_units)
ax.set_ylabel(agent.dofs[1].label_with_units)


acqf_button_options = {index: config["name"] for index, config in acqf_configs.items()}
Expand Down Expand Up @@ -135,11 +135,12 @@ def learn():
with obj_plt:
obj = agent.objectives[0]

x_samples = agent.train_inputs().detach().numpy()
y_samples = agent.train_targets(obj.name).detach().numpy()[..., 0]
x_samples = agent.raw_inputs().detach().numpy()
y_samples = agent.raw_targets(obj.name).detach().numpy()[..., 0]

x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
p = obj.model.posterior(x)
model_x = agent.dofs.transform(x)
p = obj.model.posterior(model_x)

m = p.mean.squeeze(-1, -2).detach().numpy()
e = p.variance.sqrt().squeeze(-1, -2).detach().numpy()
Expand All @@ -164,12 +165,13 @@ def learn():

with acq_plt:
x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
model_x = agent.dofs.transform(x)
x_samples = agent.train_inputs().detach().numpy()

for acqf in acqf_plt_objs.keys():
ax = acqf_plt_objs[acqf]["ax"]

acqf_obj = getattr(agent, acqf)(x).detach().numpy()
acqf_obj = getattr(agent, acqf)(model_x).detach().numpy()

acqf_norm = mpl.colors.Normalize(vmin=np.nanmin(acqf_obj), vmax=np.nanmax(acqf_obj))
acqf_plt_objs[acqf]["im"].set_data(acqf_obj.T[::-1])
Expand Down
Loading

0 comments on commit a4fa614

Please sign in to comment.