Commit e3e18e5
added targeting posterior transform
Thomas Morris committed Nov 8, 2023
1 parent cc0e885 commit e3e18e5
Showing 7 changed files with 194 additions and 70 deletions.
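The new TargetingPosteriorTransform lives in bloptools/bayesian/transforms.py, one of the three changed files not rendered below. For orientation, here is a minimal sketch of what such a transform can look like; the class body is illustrative only (the committed implementation is not shown in this excerpt), and the negative-absolute-error distance is an assumption:

```python
import torch
from botorch.acquisition.objective import PosteriorTransform


class TargetingPosteriorTransform(PosteriorTransform):
    """Sketch: reward proximity to per-objective targets instead of a plain weighted sum."""

    scalarize = True  # the transform collapses the multi-output posterior to one value

    def __init__(self, weights: torch.Tensor, targets: torch.Tensor):
        super().__init__()
        self.register_buffer("weights", weights)
        self.register_buffer("targets", targets)

    def evaluate(self, Y: torch.Tensor) -> torch.Tensor:
        # Per-objective scores; zero when an objective sits exactly on its target.
        return self.weights * -(Y - self.targets).abs()

    def forward(self, posterior):
        # Pushing a Gaussian posterior through a nonlinear targeting objective
        # has no closed form; the committed code necessarily handles this part.
        raise NotImplementedError
```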
13 changes: 6 additions & 7 deletions bloptools/bayesian/acquisition/__init__.py
@@ -1,7 +1,6 @@
import os

import yaml
-from botorch.acquisition.objective import ScalarizedPosteriorTransform

from . import analytic, monte_carlo
from .analytic import * # noqa F401
@@ -47,7 +46,7 @@ def get_acquisition_function(agent, identifier="qei", return_metadata=True, verb
constraint=agent.constraint,
model=agent.model,
best_f=agent.max_scalarized_objective,
-posterior_transform=ScalarizedPosteriorTransform(weights=agent.objective_weights_torch, offset=0),
+posterior_transform=agent.targeting_transform,
)
acq_func_meta = {"name": acq_func_name, "args": {}}

@@ -56,7 +55,7 @@ def get_acquisition_function(agent, identifier="qei", return_metadata=True, verb
constraint=agent.constraint,
model=agent.model,
best_f=agent.max_scalarized_objective,
-posterior_transform=ScalarizedPosteriorTransform(weights=agent.objective_weights_torch, offset=0),
+posterior_transform=agent.targeting_transform,
)
acq_func_meta = {"name": acq_func_name, "args": {}}

@@ -65,7 +64,7 @@ def get_acquisition_function(agent, identifier="qei", return_metadata=True, verb
constraint=agent.constraint,
model=agent.model,
best_f=agent.max_scalarized_objective,
-posterior_transform=ScalarizedPosteriorTransform(weights=agent.objective_weights_torch, offset=0),
+posterior_transform=agent.targeting_transform,
)
acq_func_meta = {"name": acq_func_name, "args": {}}

@@ -74,7 +73,7 @@ def get_acquisition_function(agent, identifier="qei", return_metadata=True, verb
constraint=agent.constraint,
model=agent.model,
best_f=agent.max_scalarized_objective,
-posterior_transform=ScalarizedPosteriorTransform(weights=agent.objective_weights_torch, offset=0),
+posterior_transform=agent.targeting_transform,
)
acq_func_meta = {"name": acq_func_name, "args": {}}

@@ -103,7 +102,7 @@ def get_acquisition_function(agent, identifier="qei", return_metadata=True, verb
constraint=agent.constraint,
model=agent.model,
beta=beta,
-posterior_transform=ScalarizedPosteriorTransform(weights=agent.objective_weights_torch, offset=0),
+posterior_transform=agent.targeting_transform,
)
acq_func_meta = {"name": acq_func_name, "args": {"beta": beta}}

@@ -114,7 +113,7 @@ def get_acquisition_function(agent, identifier="qei", return_metadata=True, verb
constraint=agent.constraint,
model=agent.model,
beta=beta,
-posterior_transform=ScalarizedPosteriorTransform(weights=agent.objective_weights_torch, offset=0),
+posterior_transform=agent.targeting_transform,
)
acq_func_meta = {"name": acq_func_name, "args": {"beta": beta}}

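Each of the six hunks above makes the same swap: a fixed linear scalarization is replaced by the agent's targeting-aware transform. The pattern, sketched with a stock BoTorch acquisition function (the module's own wrappers in .analytic and .monte_carlo, and their constraint= keyword, are not shown in this excerpt; agent is assumed to be already initialized):

```python
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.objective import ScalarizedPosteriorTransform

# before: a fixed weighted sum of the posterior outputs, blind to per-objective targets
acqf_old = qExpectedImprovement(
    model=agent.model,
    best_f=agent.max_scalarized_objective,
    posterior_transform=ScalarizedPosteriorTransform(weights=agent.objective_weights_torch, offset=0),
)

# after: the agent supplies a transform that knows each objective's target
acqf_new = qExpectedImprovement(
    model=agent.model,
    best_f=agent.max_scalarized_objective,
    posterior_transform=agent.targeting_transform,
)
```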
105 changes: 73 additions & 32 deletions bloptools/bayesian/agent.py
@@ -15,6 +15,7 @@
import pandas as pd
import scipy as sp
import torch
+from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model_list_gp_regression import ModelListGP
from databroker import Broker
@@ -26,6 +27,7 @@
from .dofs import DOF, DOFList
from .objectives import Objective, ObjectiveList
from .plans import default_acquisition_plan
+from .transforms import TargetingPosteriorTransform

warnings.filterwarnings("ignore", category=botorch.exceptions.warnings.InputDataWarning)

@@ -144,20 +146,18 @@ def tell(self, x: Mapping, y: Mapping, metadata=None, append=True, train_models=
def _update_models(self, train=True, skew_dims=None, a_priori_hypers=None):
skew_dims = skew_dims if skew_dims is not None else self.latent_dim_tuples

-# if self.initialized:
-#     cached_hypers = self.hypers

inputs = self.table.loc[:, self.dofs.subset(active=True).names].values.astype(float)

for i, obj in enumerate(self.objectives):
-self.table.loc[:, f"{obj.key}_fitness"] = targets = self._get_objective_targets(i)
-train_index = ~np.isnan(targets)
+values = self.get_objective_targets(i)
+
+train_index = ~np.isnan(values)

if not train_index.sum() >= 2:
raise ValueError("There must be at least two valid data points per objective!")

train_inputs = torch.tensor(inputs[train_index], dtype=torch.double)
-train_targets = torch.tensor(targets[train_index], dtype=torch.double).unsqueeze(-1)  # .unsqueeze(0)
+train_values = torch.tensor(values[train_index], dtype=torch.double).unsqueeze(-1)  # .unsqueeze(0)

# for constructing the log normal noise prior
# target_snr = 2e2
@@ -176,7 +176,7 @@ def _update_models(self, train=True, skew_dims=None, a_priori_hypers=None):

obj.model = models.LatentGP(
train_inputs=train_inputs,
-train_targets=train_targets,
+train_targets=train_values,
likelihood=likelihood,
skew_dims=skew_dims,
input_transform=self._subset_input_transform(active=True),
@@ -343,7 +343,7 @@ def acquire(self, acquisition_inputs):
logging.warning(f"Error in acquisition/digestion: {repr(error)}")
products = pd.DataFrame(acquisition_inputs, columns=self.dofs.subset(active=True, read_only=False).names)
for obj in self.objectives:
-products.loc[:, obj.key] = np.nan
+products.loc[:, obj.name] = np.nan

if not len(acquisition_inputs) == len(products):
raise ValueError("The table returned by the digestion function must be the same length as the sampled inputs!")
@@ -353,7 +353,7 @@
def load_data(self, data_file, append=True, train_models=True):
new_table = pd.read_hdf(data_file, key="table")
x = {key: new_table.pop(key).tolist() for key in self.dofs.names}
-y = {key: new_table.pop(key).tolist() for key in self.objectives.keys}
+y = {key: new_table.pop(key).tolist() for key in self.objectives.names}
metadata = new_table.to_dict(orient="list")
self.tell(x=x, y=y, metadata=metadata, append=append, train_models=train_models)

@@ -406,7 +406,7 @@ def learn(
new_table.loc[:, "acq_func"] = acq_func_meta["name"]

x = {key: new_table.pop(key).tolist() for key in self.dofs.names}
-y = {key: new_table.pop(key).tolist() for key in self.objectives.keys}
+y = {key: new_table.pop(key).tolist() for key in self.objectives.names}
metadata = new_table.to_dict(orient="list")
self.tell(x=x, y=y, metadata=metadata, append=append, train_models=train_models)

@@ -459,49 +459,90 @@ def model(self):
"""A model encompassing all the objectives. A single GP in the single-objective case, or a model list."""
return ModelListGP(*[obj.model for obj in self.objectives]) if len(self.objectives) > 1 else self.objectives[0].model

+def posterior(self, x):
+"""Returns the posterior of the model at the input point(s) x."""
+return self.model.posterior(x)

@property
def objective_weights_torch(self):
return torch.tensor(self.objectives.weights, dtype=torch.double)

-def _get_objective_targets(self, i):
-"""Returns the targets (what we fit to) for an objective, given the objective index."""
+def get_objective_targets(self, i):
+"""Returns the values associated with each objective."""

obj = self.objectives[i]

-targets = self.table.loc[:, obj.key].values.copy()
+targets = self.table.loc[:, obj.name].values.copy()

# check that targets values are inside acceptable values
valid = (targets > obj.limits[0]) & (targets < obj.limits[1])
targets = np.where(valid, targets, np.nan)

# transform if needed
if obj.log:
-targets = np.where(valid, np.log(targets), np.nan)
-if obj.target not in ["min", "max"]:
-targets = -np.square(np.log(targets) - np.log(obj.target))
-else:
-if obj.target not in ["min", "max"]:
-targets = -np.square(targets - obj.target)
-
-if obj.target == "min":
-targets *= -1
+targets = np.where(targets > 0, np.log(targets), np.nan)

return targets

+# def _get_objective_targets(self, i):
+#     """Returns the targets (what we fit to) for an objective, given the objective index."""
+#     obj = self.objectives[i]

+#     targets = self.table.loc[:, obj.name].values.copy()

+#     # check that targets values are inside acceptable values
+#     valid = (targets > obj.limits[0]) & (targets < obj.limits[1])
+#     targets = np.where(valid, targets, np.nan)

+#     # transform if needed
+#     if obj.log:
+#         targets = np.where(valid, np.log(targets), np.nan)
+#         if obj.target not in ["min", "max"]:
+#             targets = -np.square(np.log(targets) - np.log(obj.target))

+#     else:
+#         if obj.target not in ["min", "max"]:
+#             targets = -np.square(targets - obj.target)

+#     if obj.target == "min":
+#         targets *= -1

+#     return targets
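A standalone numeric illustration of the slimmed-down get_objective_targets (the readings, limits, and log flag are made up): out-of-limits values become NaN and log objectives are log-transformed, while the target-distance logic that used to live here moves into the posterior transform:

```python
import numpy as np

raw = np.array([0.5, 3.0, 150.0, 9.0])  # hypothetical readings for one objective
limits, log = (1.0, 100.0), True        # hypothetical obj.limits and obj.log

valid = (raw > limits[0]) & (raw < limits[1])
targets = np.where(valid, raw, np.nan)  # -> [nan, 3.0, nan, 9.0]
if log:
    # NaNs propagate; only positive in-limit values are log-transformed
    targets = np.where(targets > 0, np.log(targets), np.nan)  # -> [nan, 1.10, nan, 2.20]
```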

+@property
+def scalarizing_transform(self):
+return ScalarizedPosteriorTransform(weights=self.objective_weights_torch, offset=0)
+
+@property
+def targeting_transform(self):
+return TargetingPosteriorTransform(weights=self.objective_weights_torch, targets=self.pseudo_targets)

@property
-def n_objs(self):
-"""Returns a (num_objectives x n_observations) array of objectives"""
-return len(self.objectives)
+def pseudo_targets(self):
+"""Targets for the posterior transform"""
+return torch.tensor(
+[
+self.objectives_targets[..., i].max()
+if t == "max"
+else self.objectives_targets[..., i].min()
+if t == "min"
+else t
+for i, t in enumerate(self.objectives.targets)
+]
+)
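To make the comprehension above concrete, a small hypothetical case with two objectives, one set to "max" and one to a fixed target of 2.0:

```python
import torch

# rows are observations, columns are objectives (the layout of objectives_targets below)
obs = torch.tensor([[1.0, 3.0], [4.0, 1.5], [2.5, 2.2]], dtype=torch.double)
targets = ["max", 2.0]

pseudo = torch.tensor(
    [obs[..., i].max() if t == "max" else obs[..., i].min() if t == "min" else t for i, t in enumerate(targets)]
)
# pseudo -> tensor([4., 2.]): "max" resolves to the best value observed so far
```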

@property
def objectives_targets(self):
"""Returns a (num_objectives x n_obs) array of objectives"""
-return torch.cat([torch.tensor(self._get_objective_targets(i))[..., None] for i in range(self.n_objs)], dim=1)
+return torch.cat(
+[torch.tensor(self.get_objective_targets(i))[..., None] for i in range(len(self.objectives))], dim=1
+)

@property
def scalarized_objectives(self):
"""Returns a (n_obs,) array of scalarized objectives"""
-return (self.objectives_targets * self.objectives.weights).sum(axis=-1)
+return self.targeting_transform.evaluate(self.objectives_targets).sum(axis=-1)
+# return (self.objectives_targets * self.objectives.signed_weights).sum(axis=-1)
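Taken together, the new scalarization path is: fitted per-objective values → targeting transform → sum over objectives. A sketch, assuming an initialized agent:

```python
Y = agent.objectives_targets                    # (n_obs, n_objs)
scores = agent.targeting_transform.evaluate(Y)  # per-objective closeness-to-target scores
scalarized = scores.sum(axis=-1)                # (n_obs,), what scalarized_objectives returns
best_f = scalarized.max()                       # what max_scalarized_objective feeds to best_f above
```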

@property
def max_scalarized_objective(self):
@@ -615,7 +656,7 @@ def sampler(self, n, d):

def _set_hypers(self, hypers):
for obj in self.objectives:
-obj.model.load_state_dict(hypers[obj.key])
+obj.model.load_state_dict(hypers[obj.name])
self.classifier.load_state_dict(hypers["classifier"])

@property
Expand All @@ -625,17 +666,17 @@ def hypers(self):
for key, value in self.classifier.state_dict().items():
hypers["classifier"][key] = value
for obj in self.objectives:
-hypers[obj.key] = {}
+hypers[obj.name] = {}
for key, value in obj.model.state_dict().items():
-hypers[obj.key][key] = value
+hypers[obj.name][key] = value

return hypers

def save_hypers(self, filepath):
"""Save the agent's fitted hyperparameters to a given filepath."""
hypers = self.hypers
with h5py.File(filepath, "w") as f:
-for model_key in hypers.keys():
+for model_key in hypers.names():
f.create_group(model_key)
for param_key, param_value in hypers[model_key].items():
f[model_key].create_dataset(param_key, data=param_value)
@@ -645,7 +686,7 @@ def load_hypers(filepath):
"""Load hyperparameters from a file."""
hypers = {}
with h5py.File(filepath, "r") as f:
-for model_key in f.keys():
+for model_key in f.names():
hypers[model_key] = OrderedDict()
for param_key, param_value in f[model_key].items():
hypers[model_key][param_key] = torch.tensor(np.atleast_1d(param_value[()]))
7 changes: 4 additions & 3 deletions bloptools/bayesian/dofs.py
@@ -9,7 +9,7 @@
from ophyd import Signal, SignalRO

DEFAULT_BOUNDS = (-5.0, +5.0)
DOF_FIELDS = ["name", "readback", "lower_limit", "upper_limit", "units", "active", "read_only", "tags"]
DOF_FIELDS = ["name", "description", "readback", "lower_limit", "upper_limit", "units", "active", "read_only", "tags"]

numeric = Union[float, int]

@@ -33,13 +33,14 @@ def _validate_dofs(dofs):
@dataclass
class DOF:
device: Signal = None
-limits: Tuple[float, float] = (-10.0, 10.0)
+description: str = None
name: str = None
+limits: Tuple[float, float] = (-10.0, 10.0)
units: str = ""
read_only: bool = False
active: bool = True
tags: list = field(default_factory=list)
-latent_group = None
+latent_group: str = None

def __post_init__(self):
self.uuid = str(uuid.uuid4())
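A hypothetical construction showing the reordered fields, the new description attribute, and the now-annotated latent_group:

```python
from ophyd import Signal

from bloptools.bayesian.dofs import DOF

dof = DOF(
    device=Signal(name="x1"),
    description="sample stage horizontal position",  # new field in this commit
    name="x1",
    limits=(-5.0, 5.0),
    units="mm",
)
```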
16 changes: 16 additions & 0 deletions bloptools/bayesian/models.py
@@ -21,6 +21,22 @@ def __init__(self, train_inputs, train_targets, skew_dims=True, *args, **kwargs)
)


+class TargetingGP(botorch.models.gp_regression.SingleTaskGP):
+def __init__(self, train_inputs, train_targets, skew_dims=True, *args, **kwargs):
+super().__init__(train_inputs, train_targets, *args, **kwargs)
+
+self.mean_module = gpytorch.means.ConstantMean(constant_prior=gpytorch.priors.NormalPrior(loc=0, scale=1))
+
+self.covar_module = kernels.LatentKernel(
+num_inputs=train_inputs.shape[-1],
+num_outputs=train_targets.shape[-1],
+skew_dims=skew_dims,
+diag_prior=True,
+scale=True,
+**kwargs
+)
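A minimal usage sketch for the new TargetingGP, assuming the same tensor conventions as LatentGP above (double precision, targets of shape (n, 1)):

```python
import torch

from bloptools.bayesian import models

X = torch.rand(16, 3, dtype=torch.double)  # 16 observations of 3 inputs
y = torch.rand(16, 1, dtype=torch.double)  # one objective

model = models.TargetingGP(train_inputs=X, train_targets=y, skew_dims=True)
```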


class LatentDirichletClassifier(LatentGP):
def __init__(self, train_inputs, train_targets, skew_dims=True, *args, **kwargs):
super().__init__(train_inputs, train_targets, skew_dims, *args, **kwargs)
(3 more changed files not rendered in this view, including bloptools/bayesian/transforms.py)

