Skip to content

Commit

Permalink
make objectives toggleable
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Nov 10, 2023
1 parent 6576bc8 commit 34323e7
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 17 deletions.
2 changes: 1 addition & 1 deletion bloptools/bayesian/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_acquisition_function(agent, identifier="qei", return_metadata=True, verb
acq_func_name = parse_acq_func_identifier(identifier)
acq_func_config = config["upper_confidence_bound"]

if config[acq_func_name]["multitask_only"] and (agent.num_tasks == 1):
if config[acq_func_name]["multitask_only"] and (len(agent.objectives) == 1):
raise ValueError(f'Acquisition function "{acq_func_name}" is only for multi-task optimization problems!')

# there is probably a better way to structure this
Expand Down
20 changes: 11 additions & 9 deletions bloptools/bayesian/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __init__(
self.initialized = False
self.a_priori_hypers = None

self.n_last_trained = 0

def tell(self, x: Mapping, y: Mapping, metadata=None, append=True, train_models=True, hypers=None):
"""
Inform the agent about new inputs and targets for the model.
Expand Down Expand Up @@ -147,10 +149,10 @@ def tell(self, x: Mapping, y: Mapping, metadata=None, append=True, train_models=

# TODO: should be a check per model
if len(self.table) > 2:
# if n_before_tell % self.train_every != n_after_tell % self.train_every:
self._update_models(train=train_models, a_priori_hypers=hypers)
if int(self.n_last_trained / self.train_every) != int(len(self.table) / self.train_every):
self._construct_models(train=train_models, a_priori_hypers=hypers)

def _update_models(self, train=True, skew_dims=None, a_priori_hypers=None):
def _construct_models(self, train=True, skew_dims=None, a_priori_hypers=None):
skew_dims = skew_dims if skew_dims is not None else self.latent_dim_tuples

inputs = self.table.loc[:, self.dofs.subset(active=True).names].values.astype(float)
Expand All @@ -173,7 +175,7 @@ def _update_models(self, train=True, skew_dims=None, a_priori_hypers=None):

likelihood = gpytorch.likelihoods.GaussianLikelihood(
noise_constraint=gpytorch.constraints.Interval(
torch.tensor(1e-2).square(),
torch.tensor(1e-4).square(),
torch.tensor(1 / obj.min_snr).square(),
),
# noise_prior=gpytorch.priors.torch_priors.LogNormalPrior(loc=loc, scale=scale),
Expand Down Expand Up @@ -293,14 +295,12 @@ def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True):
raise ValueError()

# define dummy acqf objective
acqf_obj = None
acqf_obj = 0

acq_func_meta["duration"] = duration = ttime.monotonic() - start_time

if self.verbose:
print(
f"found points {acq_points} with acqf {acq_func_meta['name']} in {duration:.01f} seconds (obj = {acqf_obj})"
)
print(f"found points {acq_points} in {1e3*duration:.01f} ms (obj = {acqf_obj})")

if route and n > 1:
routing_index = utils.route(self.dofs.subset(active=True, read_only=False).readback, acq_points)
Expand Down Expand Up @@ -619,7 +619,7 @@ def forget(self, index, train=True):
Make the agent forget some index of the data table.
"""
self.table.drop(index=index, inplace=True)
self._update_models(train=train)
self._construct_models(train=train)

def forget_last_n(self, n, train=True):
"""
Expand Down Expand Up @@ -687,6 +687,8 @@ def _train_models(self, **kwargs):
if self.verbose:
print(f"trained models in {ttime.monotonic() - t0:.01f} seconds")

self.n_last_trained = len(self.table)

@property
def all_acq_funcs(self):
"""Description and identifiers for all supported acquisition functions."""
Expand Down
6 changes: 3 additions & 3 deletions bloptools/bayesian/dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ophyd import Signal, SignalRO

DEFAULT_BOUNDS = (-5.0, +5.0)
DOF_FIELDS = ["name", "description", "readback", "lower_limit", "upper_limit", "units", "active", "read_only", "tags"]
DOF_FIELDS = ["description", "readback", "lower_limit", "upper_limit", "units", "active", "read_only", "tags"]

numeric = Union[float, int]

Expand Down Expand Up @@ -110,9 +110,9 @@ def __repr__(self):
@property
def summary(self) -> pd.DataFrame:
table = pd.DataFrame(columns=DOF_FIELDS)
for i, dof in enumerate(self.dofs):
for dof in self.dofs:
for attr in table.columns:
table.loc[i, attr] = getattr(dof, attr)
table.loc[dof.name, attr] = getattr(dof, attr)

# convert dtypes
for attr in ["readback", "lower_limit", "upper_limit"]:
Expand Down
5 changes: 3 additions & 2 deletions bloptools/bayesian/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

numeric = Union[float, int]

DEFAULT_MINIMUM_SNR = 1e1
OBJ_FIELDS = ["description", "target", "limits", "weight", "log", "n", "snr", "min_snr"]
DEFAULT_MINIMUM_SNR = 1e2
OBJ_FIELDS = ["description", "target", "active", "limits", "weight", "log", "n", "snr", "min_snr"]


class DuplicateNameError(ValueError):
Expand All @@ -30,6 +30,7 @@ class Objective:
target: Union[float, str] = "max"
log: bool = False
weight: numeric = 1.0
active: bool = True
limits: Tuple[numeric, numeric] = None
min_snr: numeric = DEFAULT_MINIMUM_SNR
units: str = None
Expand Down
3 changes: 2 additions & 1 deletion bloptools/bayesian/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,14 @@ def _plot_acqf_many_dofs(

# test_inputs has shape (..., 1, n_active_dofs)
test_inputs = agent.test_inputs_grid() if gridded else agent.test_inputs(n=1024)
*test_dim, input_dim = test_inputs.shape
test_x = test_inputs[..., 0, axes[0]].detach().squeeze().numpy()
test_y = test_inputs[..., 0, axes[1]].detach().squeeze().numpy()

for iacq_func, acq_func_identifier in enumerate(acq_funcs):
acq_func, acq_func_meta = acquisition.get_acquisition_function(agent, acq_func_identifier)

test_acqf = acq_func(test_inputs).detach().squeeze().numpy()
test_acqf = acq_func(test_inputs.reshape(-1, 1, input_dim)).detach().reshape(test_dim).squeeze().numpy()

if gridded:
agent.acq_axes[iacq_func].set_title(acq_func_meta["name"])
Expand Down
2 changes: 1 addition & 1 deletion bloptools/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def agent(db):
@pytest.fixture(scope="function")
def multi_agent(db):
"""
A simple agent minimizing two Styblinski-Tang functions
A simple agent minimizing two Himmelblau's functions
"""

def digestion(db, uid):
Expand Down

0 comments on commit 34323e7

Please sign in to comment.