Skip to content

Commit

Permalink
reconfigure digestion syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Apr 23, 2024
1 parent 1b80d29 commit 8accb61
Show file tree
Hide file tree
Showing 16 changed files with 182 additions and 147 deletions.
14 changes: 6 additions & 8 deletions docs/source/tutorials/hyperparameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@
"metadata": {},
"outputs": [],
"source": [
"def digestion(db, uid):\n",
" products = db[uid].table()\n",
"def digestion(df):\n",
" for index, entry in df.iterrows():\n",
" df.loc[index, \"booth\"] = functions.booth(entry.x1, entry.x2)\n",
"\n",
" for index, entry in products.iterrows():\n",
" products.loc[index, \"booth\"] = functions.booth(entry.x1, entry.x2)\n",
"\n",
" return products"
" return df"
]
},
{
Expand Down Expand Up @@ -91,7 +89,7 @@
" db=db,\n",
")\n",
"\n",
"RE(agent.learn(acq_func=\"qr\", n=16))\n",
"RE(agent.learn(acqf=\"qr\", n=16))\n",
"\n",
"agent.plot_objectives()"
]
Expand All @@ -114,7 +112,7 @@
},
"outputs": [],
"source": [
"agent.plot_acquisition(acq_func=\"qei\")"
"agent.plot_acquisition(acqf=\"qei\")"
]
},
{
Expand Down
16 changes: 7 additions & 9 deletions docs/source/tutorials/introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,11 @@
"metadata": {},
"outputs": [],
"source": [
"def digestion(db, uid):\n",
" products = db[uid].table()\n",
"def digestion(df):\n",
" for index, entry in df.iterrows():\n",
" df.loc[index, \"himmelblau\"] = functions.himmelblau(entry.x1, entry.x2)\n",
"\n",
" for index, entry in products.iterrows():\n",
" products.loc[index, \"himmelblau\"] = functions.himmelblau(entry.x1, entry.x2)\n",
"\n",
" return products"
" return df"
]
},
{
Expand Down Expand Up @@ -184,7 +182,7 @@
"metadata": {},
"outputs": [],
"source": [
"agent.all_acq_funcs"
"agent.all_acqfs"
]
},
{
Expand All @@ -203,7 +201,7 @@
"metadata": {},
"outputs": [],
"source": [
"agent.plot_acquisition(acq_func=\"qei\")"
"agent.plot_acquisition(acqf=\"qei\")"
]
},
{
Expand Down Expand Up @@ -246,7 +244,7 @@
"outputs": [],
"source": [
"res = agent.ask(\"qei\", n=8, route=True)\n",
"agent.plot_acquisition(acq_func=\"qei\")\n",
"agent.plot_acquisition(acqf=\"qei\")\n",
"plt.scatter(*res[\"points\"].T, marker=\"d\", facecolor=\"w\", edgecolor=\"k\")\n",
"plt.plot(\n",
" *res[\"points\"].T,\n",
Expand Down
16 changes: 7 additions & 9 deletions docs/source/tutorials/pareto-fronts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,16 @@
"import numpy as np\n",
"\n",
"\n",
"def digestion(db, uid):\n",
" products = db[uid].table()\n",
"\n",
" for index, entry in products.iterrows():\n",
"def digestion(df):\n",
" for index, entry in df.iterrows():\n",
" x1, x2 = entry.x1, entry.x2\n",
"\n",
" products.loc[index, \"f1\"] = (x1 - 2) ** 2 + (x2 - 1) + 2\n",
" products.loc[index, \"f2\"] = 9 * x1 - (x2 - 1) + 2\n",
" products.loc[index, \"c1\"] = x1**2 + x2**2\n",
" products.loc[index, \"c2\"] = x1 - 3 * x2 + 10\n",
" df.loc[index, \"f1\"] = (x1 - 2) ** 2 + (x2 - 1) + 2\n",
" df.loc[index, \"f2\"] = 9 * x1 - (x2 - 1) + 2\n",
" df.loc[index, \"c1\"] = x1**2 + x2**2\n",
" df.loc[index, \"c2\"] = x1 - 3 * x2 + 10\n",
"\n",
" return products\n",
" return df\n",
"\n",
"\n",
"dofs = [\n",
Expand Down
18 changes: 10 additions & 8 deletions scripts/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@
err_cbar = obj_plt.fig.colorbar(mappable=im3, ax=[ax3], location="bottom", aspect=16)

for ax in [ax1, ax2, ax3]:
ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)
ax.set_xlabel(agent.dofs[0].label_with_units)
ax.set_ylabel(agent.dofs[1].label_with_units)


acqf_configs = {
Expand Down Expand Up @@ -103,8 +103,8 @@
acqf_plt_objs[acqf]["hist"] = ax.scatter([], [])
acqf_plt_objs[acqf]["best"] = ax.scatter([], [])

ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)
ax.set_xlabel(agent.dofs[0].label_with_units)
ax.set_ylabel(agent.dofs[1].label_with_units)


acqf_button_options = {index: config["name"] for index, config in acqf_configs.items()}
Expand Down Expand Up @@ -135,11 +135,12 @@ def learn():
with obj_plt:
obj = agent.objectives[0]

x_samples = agent.train_inputs().detach().numpy()
y_samples = agent.train_targets(obj.name).detach().numpy()[..., 0]
x_samples = agent.raw_inputs().detach().numpy()
y_samples = agent.raw_targets(obj.name).detach().numpy()[..., 0]

x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
p = obj.model.posterior(x)
model_x = agent.dofs.transform(x)
p = obj.model.posterior(model_x)

m = p.mean.squeeze(-1, -2).detach().numpy()
e = p.variance.sqrt().squeeze(-1, -2).detach().numpy()
Expand All @@ -164,12 +165,13 @@ def learn():

with acq_plt:
x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
model_x = agent.dofs.transform(x)
x_samples = agent.train_inputs().detach().numpy()

for acqf in acqf_plt_objs.keys():
ax = acqf_plt_objs[acqf]["ax"]

acqf_obj = getattr(agent, acqf)(x).detach().numpy()
acqf_obj = getattr(agent, acqf)(model_x).detach().numpy()

acqf_norm = mpl.colors.Normalize(vmin=np.nanmin(acqf_obj), vmax=np.nanmax(acqf_obj))
acqf_plt_objs[acqf]["im"].set_data(acqf_obj.T[::-1])
Expand Down
24 changes: 18 additions & 6 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
dets: Sequence[Signal] = [],
acquistion_plan=default_acquisition_plan,
digestion: Callable = default_digestion_function,
digestion_kwargs: dict = {},
verbose: bool = False,
tolerate_acquisition_errors=False,
sample_center_on_init=False,
Expand All @@ -89,7 +90,9 @@ def __init__(
acquisition_plan : optional
A plan that samples the beamline for some given inputs.
digestion :
A function to digest the output of the acquisition, taking arguments (db, uid).
A function to digest the output of the acquisition, taking a DataFrame as an argument.
digestion_kwargs :
Additional keyword arguments forwarded to the digestion function on every call.
db : optional
A databroker instance.
verbose : bool
Expand Down Expand Up @@ -130,6 +133,7 @@ def __init__(
self.dets = dets
self.acquisition_plan = acquistion_plan
self.digestion = digestion
self.digestion_kwargs = digestion_kwargs

self.verbose = verbose

Expand Down Expand Up @@ -159,10 +163,17 @@ def __iter__(self):
def __getattr__(self, attr):
acqf_config = acquisition.parse_acqf_identifier(attr, strict=False)
if acqf_config is not None:
acqf, _ = _construct_acqf(acqf_name=acqf_config["name"])
acqf, _ = _construct_acqf(agent=self, acqf_name=acqf_config["name"])
return acqf
raise AttributeError(f"No attribute named '{attr}'.")

def refresh(self):
    """Reconstruct the agent's models and then train them.

    Calls ``self._construct_all_models()`` followed by
    ``self._train_all_models()``, in that order. Returns ``None``.
    """
    self._construct_all_models()
    self._train_all_models()

def redigest(self):
    """Re-run the digestion function on the agent's existing table.

    ``self.table`` is passed to ``self.digestion`` together with
    ``self.digestion_kwargs`` as keyword arguments, and ``self.table`` is
    replaced by whatever the digestion function returns.
    """
    self.table = self.digestion(self.table, **self.digestion_kwargs)

def sample(self, n: int = DEFAULT_MAX_SAMPLES, method: str = "quasi-random") -> torch.Tensor:
"""
Returns a (..., 1, n_active_dofs) tensor of points sampled within the parameter space.
Expand Down Expand Up @@ -272,7 +283,9 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k
duration = 1e3 * (ttime.monotonic() - start_time)

if route and n > 1:
routing_index = utils.route(self.dofs.subset(active=True, read_only=False).readback, points)
current_points = np.array([dof.readback for dof in active_dofs if not dof.read_only])
travel_expenses = np.array([dof.travel_expense for dof in active_dofs if not dof.read_only])
routing_index = utils.route(current_points, points, dim_weights=travel_expenses)
points = points[routing_index]

if upsample > 1:
Expand Down Expand Up @@ -479,8 +492,7 @@ def acquire(self, acquisition_inputs):
[*self.dets, *self.dofs.devices],
delay=self.trigger_delay,
)

products = self.digestion(self.db, uid)
products = self.digestion(self.db[uid].table(), **self.digestion_kwargs)

except KeyboardInterrupt as interrupt:
raise interrupt
Expand Down Expand Up @@ -549,7 +561,7 @@ def model(self):

def posterior(self, x):
"""A model encompassing all the objectives. A single GP in the single-objective case, or a model list."""
return self.model.posterior(torch.tensor(x))
return self.model.posterior(self.dofs.transform(torch.tensor(x)))

@property
def fitness_model(self):
Expand Down
18 changes: 10 additions & 8 deletions src/blop/bayesian/acquisition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import os

import yaml
import pandas as pd
from botorch.utils.transforms import normalize
import yaml

from . import analytic, monte_carlo
from .analytic import * # noqa F401
from .monte_carlo import * # noqa F401

# from botorch.utils.transforms import normalize


here, this_filename = os.path.split(__file__)

with open(f"{here}/config.yml", "r") as f:
config = yaml.safe_load(f)


def all_acqfs(columns=None):
    """Return a table of the acquisition functions described in ``config``.

    Parameters
    ----------
    columns : list of str, optional
        The config fields to include as columns. When omitted, defaults to
        ``["identifier", "type", "multitask_only", "description"]``.

    Returns
    -------
    pandas.DataFrame
        One row per acquisition function, with the index named ``"name"``,
        sorted by ``"type"`` and then by ``"name"``.
    """
    # Resolve the default inside the body so that no mutable list object is
    # shared between calls through the signature.
    if columns is None:
        columns = ["identifier", "type", "multitask_only", "description"]

    # ``config`` is keyed by acquisition function name, so transposing the
    # frame puts one acquisition function on each row.
    acqfs = pd.DataFrame(config).T[columns]
    acqfs.index.name = "name"
    return acqfs.sort_values(["type", "name"])


def parse_acqf_identifier(identifier, strict=True):
for acqf_name in config.keys():
if identifier.lower() in [acqf_name, config[acqf_name]["identifier"]]:
Expand All @@ -26,6 +30,7 @@ def parse_acqf_identifier(identifier, strict=True):
raise ValueError(f"'{identifier}' is not a valid acquisition function identifier.")
return None


def _construct_acqf(agent, acqf_name, **acqf_kwargs):
"""Generates an acquisition function from a supplied identifier. A list of acquisition functions and
their identifiers can be found at `agent.all_acqfs`.
Expand All @@ -38,7 +43,6 @@ def _construct_acqf(agent, acqf_name, **acqf_kwargs):

# there is probably a better way to structure this
if acqf_name == "expected_improvement":

acqf_kwargs["best_f"] = agent.best_f(weights="default")

acqf = analytic.ConstrainedLogExpectedImprovement(
Expand All @@ -49,7 +53,6 @@ def _construct_acqf(agent, acqf_name, **acqf_kwargs):
)

elif acqf_name == "monte_carlo_expected_improvement":

acqf_kwargs["best_f"] = agent.best_f(weights="default")

acqf = monte_carlo.qConstrainedExpectedImprovement(
Expand All @@ -60,7 +63,6 @@ def _construct_acqf(agent, acqf_name, **acqf_kwargs):
)

elif acqf_name == "probability_of_improvement":

acqf_kwargs["best_f"] = agent.best_f(weights="default")

acqf = analytic.ConstrainedLogProbabilityOfImprovement(
Expand All @@ -86,15 +88,15 @@ def _construct_acqf(agent, acqf_name, **acqf_kwargs):
)

elif acqf_name == "monte_carlo_noisy_expected_hypervolume_improvement":

acqf_kwargs["ref_point"] = acqf_kwargs.get("ref_point", agent.random_ref_point)

acqf = monte_carlo.qConstrainedNoisyExpectedHypervolumeImprovement(
constraint=agent.constraint,
model=agent.fitness_model,
X_baseline=agent.input_normalization.forward(agent.train_inputs()),
# X_baseline=agent.input_normalization.forward(agent.train_inputs())[],
X_baseline=agent.dofs.transform(agent.train_inputs(active=True)),
prune_baseline=True,
**acqf_kwargs
**acqf_kwargs,
)

elif acqf_name == "upper_confidence_bound":
Expand Down
8 changes: 5 additions & 3 deletions src/blop/digestion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
def default_digestion_function(db, uid):
products = db[uid].table(fill=True)
return products
import pandas as pd


def default_digestion_function(df: pd.DataFrame) -> pd.DataFrame:
    """Return the input table unchanged.

    This is the identity digestion: it performs no post-processing and hands
    back the very same DataFrame object it was given (no copy is made).

    Parameters
    ----------
    df : pandas.DataFrame
        The table to digest.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in as ``df``.
    """
    return df
27 changes: 9 additions & 18 deletions src/blop/dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
"description": "str",
"readback": "object",
"type": "str",
"units": "str",
"tags": "object",
"transform": "str",
"search_domain": "object",
"trust_domain": "object",
"domain": "object",
"active": "bool",
"read_only": "bool",
"units": "str",
"tags": "object",
}

DOF_TYPES = ["continuous", "binary", "ordinal", "categorical"]
Expand Down Expand Up @@ -125,14 +125,15 @@ class DOF:
name: str = None
description: str = ""
type: str = None
transform: str = None
search_domain: Union[Tuple[float, float], Sequence] = None
trust_domain: Union[Tuple[float, float], Sequence] = None
units: str = None
read_only: bool = False
active: bool = True
transform: str = None
read_only: bool = False
tags: list = field(default_factory=list)
device: Signal = None
travel_expense: float = 1

def __repr__(self):
nodef_f_vals = ((f.name, attrgetter(f.name)(self)) for f in fields(self))
Expand Down Expand Up @@ -220,18 +221,6 @@ def __post_init__(self):
# all dof degrees of freedom are hinted
self.device.kind = "hinted"

@property
def domain(self):
    """
    The total domain of the DOF.

    With no transform set, a DOF of type ``"continuous"`` has the domain
    ``(-inf, inf)`` and any other type uses its ``search_domain``. With a
    transform set, the domain is looked up in ``SUPPORTED_DOF_TRANSFORMS``
    under the transform's name.
    """
    # NOTE(review): indentation was lost in this view; the nesting below is
    # reconstructed so that the ``else`` pairs with the type check — confirm.
    if self.transform is None:
        if self.type == "continuous":
            return (-np.inf, np.inf)
        else:
            return self.search_domain
    return SUPPORTED_DOF_TRANSFORMS[self.transform]

@property
def _search_domain(self):
"""
Expand Down Expand Up @@ -363,8 +352,10 @@ def __call__(self, *args, **kwargs):

def __getattr__(self, attr):
# This is called if we can't find the attribute in the normal way.
if attr in DOF_FIELD_TYPES.keys():
return np.array([getattr(dof, attr) for dof in self.dofs])
if all([hasattr(dof, attr) for dof in self.dofs]):
if DOF_FIELD_TYPES.get(attr) in ["float", "int", "bool"]:
return np.array([getattr(dof, attr) for dof in self.dofs])
return [getattr(dof, attr) for dof in self.dofs]
if attr in self.names:
return self.__getitem__(attr)

Expand Down
Loading

0 comments on commit 8accb61

Please sign in to comment.