Skip to content

Commit

Permalink
better controls for dofs and objs
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Jan 12, 2024
1 parent 85ce3a3 commit 616b2df
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 111 deletions.
3 changes: 3 additions & 0 deletions blop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from . import utils # noqa F401
from ._version import get_versions
from .dofs import DOF # noqa F401
from .objectives import Objective # noqa F401

__version__ = get_versions()["version"]
del get_versions
2 changes: 0 additions & 2 deletions blop/bayesian/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
from .agent import * # noqa F401
from .dofs import * # noqa F401
from .objectives import * # noqa F401
39 changes: 13 additions & 26 deletions blop/bayesian/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from ophyd import Signal

from .. import utils
from ..dofs import DOF, DOFList
from ..objectives import Objective, ObjectiveList
from . import acquisition, models, plotting
from .digestion import default_digestion_function
from .dofs import DOF, DOFList
from .objectives import Objective, ObjectiveList
from .plans import default_acquisition_plan
from .transforms import TargetingPosteriorTransform

Expand Down Expand Up @@ -420,7 +420,7 @@ def learn(
"""

if self.sample_center_on_init and not self.initialized:
center_inputs = np.atleast_2d(self.dofs.subset(active=True, read_only=False).limits.mean(axis=1))
center_inputs = np.atleast_2d(self.dofs.subset(active=True, read_only=False).search_bounds.mean(axis=1))
new_table = yield from self.acquire(center_inputs)
new_table.loc[:, "acq_func"] = "sample_center_on_init"

Expand Down Expand Up @@ -453,8 +453,8 @@ def _construct_model(self, obj, skew_dims=None):

likelihood = gpytorch.likelihoods.GaussianLikelihood(
noise_constraint=gpytorch.constraints.Interval(
torch.tensor(1e-4).square(),
torch.tensor(1 / obj.min_snr).square(),
torch.tensor(obj.min_noise),
torch.tensor(obj.max_noise),
),
# noise_prior=gpytorch.priors.torch_priors.LogNormalPrior(loc=loc, scale=scale),
)
Expand Down Expand Up @@ -557,20 +557,6 @@ def scalarizing_transform(self):
def targeting_transform(self):
return TargetingPosteriorTransform(weights=self.objective_weights_torch, targets=self.objectives.targets)

@property
def pseudo_targets(self):
"""Targets for the posterior transform"""
return torch.tensor(
[
self.train_targets(active=True)[..., i].max()
if t == "max"
else self.train_targets(active=True)[..., i].min()
if t == "min"
else t
for i, t in enumerate(self.objectives.targets)
]
)

@property
def scalarized_objectives(self):
"""Returns a (n_obs,) array of scalarized objectives"""
Expand Down Expand Up @@ -607,7 +593,9 @@ def test_inputs_grid(self, max_inputs=MAX_TEST_INPUTS):
for tensor in torch.meshgrid(
*[
torch.linspace(lower_limit, upper_limit, n_side)
for (lower_limit, upper_limit), n_side in zip(self.dofs.subset(active=True).limits, n_sides)
for (lower_limit, upper_limit), n_side in zip(
self.dofs.subset(active=True).search_bounds, n_sides
)
],
indexing="ij",
)
Expand All @@ -625,10 +613,9 @@ def test_inputs(self, n=MAX_TEST_INPUTS):
@property
def acquisition_function_bounds(self):
"""Returns a (2, n_active_dof) array of bounds for the acquisition function"""
active_dofs = self.dofs.subset(active=True)

acq_func_lower_bounds = [dof.lower_limit if not dof.read_only else dof.readback for dof in active_dofs]
acq_func_upper_bounds = [dof.upper_limit if not dof.read_only else dof.readback for dof in active_dofs]
acq_func_lower_bounds = np.where(self.dofs.read_only, self.dofs.readback, self.dofs.search_lower_bounds)
acq_func_upper_bounds = np.where(self.dofs.read_only, self.dofs.readback, self.dofs.search_upper_bounds)

return torch.tensor(np.vstack([acq_func_lower_bounds, acq_func_upper_bounds]), dtype=torch.double)

Expand All @@ -652,7 +639,7 @@ def input_transform(self):

def _subset_input_transform(self, active=None, read_only=None, tags=[]):
# torch likes limits to be (2, n_dof) and not (n_dof, 2)
torch_limits = torch.tensor(self.dofs.subset(active, read_only, tags).limits.T, dtype=torch.double)
torch_limits = torch.tensor(self.dofs.subset(active, read_only, tags).search_bounds.T, dtype=torch.double)
offset = torch_limits.min(dim=0).values
coefficient = torch_limits.max(dim=0).values - offset
return botorch.models.transforms.input.AffineInputTransform(
Expand Down Expand Up @@ -792,7 +779,7 @@ def train_inputs(self, dof_name=None, **subset_kwargs):
inputs = self.table.loc[:, dof.name].values.copy()

# check that inputs values are inside acceptable values
valid = (inputs >= dof.limits[0]) & (inputs <= dof.limits[1])
valid = (inputs >= dof.trust_bounds[0]) & (inputs <= dof.trust_bounds[1])
inputs = np.where(valid, inputs, np.nan)

# transform if needed
Expand All @@ -811,7 +798,7 @@ def train_targets(self, obj_name=None, **subset_kwargs):
targets = self.table.loc[:, obj.name].values.copy()

# check that targets values are inside acceptable values
valid = (targets >= obj.limits[0]) & (targets <= obj.limits[1])
valid = (targets >= obj.trust_bounds[0]) & (targets <= obj.trust_bounds[1])
targets = np.where(valid, targets, np.nan)

# transform if needed
Expand Down
26 changes: 13 additions & 13 deletions blop/bayesian/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _plot_objs_one_dof(agent, size=16, lw=1e0):
alpha=0.5**z,
)

agent.obj_axes[obj_index].set_xlim(*x_dof.limits)
agent.obj_axes[obj_index].set_xlim(*x_dof.search_bounds)
agent.obj_axes[obj_index].set_xlabel(x_dof.label)
agent.obj_axes[obj_index].set_ylabel(obj.label)

Expand All @@ -67,10 +67,10 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL

agent.obj_fig, agent.obj_axes = plt.subplots(
len(agent.objectives),
3,
figsize=(10, 4 * len(agent.objectives)),
4,
figsize=(12, 4 * len(agent.objectives)),
constrained_layout=True,
dpi=256,
dpi=160,
)

agent.obj_axes = np.atleast_2d(agent.obj_axes)
Expand Down Expand Up @@ -144,7 +144,7 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
obj_cbar.set_label(obj.label)
err_cbar.set_label(f"{obj.label} error")

col_names = ["samples", "posterior mean", "posterior std. dev."]
col_names = ["samples", "posterior mean", "posterior std. dev.", "fitness"]

pad = 5

Expand Down Expand Up @@ -179,8 +179,8 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
for ax in agent.obj_axes.ravel():
ax.set_xlabel(x_dof.label)
ax.set_ylabel(y_dof.label)
ax.set_xlim(*x_dof.limits)
ax.set_ylim(*y_dof.limits)
ax.set_xlim(*x_dof.search_bounds)
ax.set_ylim(*y_dof.search_bounds)


def _plot_acqf_one_dof(agent, acq_funcs, lw=1e0, **kwargs):
Expand All @@ -205,7 +205,7 @@ def _plot_acqf_one_dof(agent, acq_funcs, lw=1e0, **kwargs):

agent.acq_axes[iacq_func].plot(test_inputs.squeeze(-2), test_acqf, lw=lw, color=color)

agent.acq_axes[iacq_func].set_xlim(*x_dof.limits)
agent.acq_axes[iacq_func].set_xlim(*x_dof.search_bounds)
agent.acq_axes[iacq_func].set_xlabel(x_dof.label)
agent.acq_axes[iacq_func].set_ylabel(acq_func_meta["name"])

Expand Down Expand Up @@ -267,8 +267,8 @@ def _plot_acqf_many_dofs(
for ax in agent.acq_axes.ravel():
ax.set_xlabel(x_dof.label)
ax.set_ylabel(y_dof.label)
ax.set_xlim(*x_dof.limits)
ax.set_ylim(*y_dof.limits)
ax.set_xlim(*x_dof.search_bounds)
ax.set_ylim(*y_dof.search_bounds)


def _plot_valid_one_dof(agent, size=16, lw=1e0):
Expand All @@ -282,7 +282,7 @@ def _plot_valid_one_dof(agent, size=16, lw=1e0):

agent.valid_ax.scatter(x_values, agent.all_objectives_valid, s=size)
agent.valid_ax.plot(test_inputs.squeeze(-2), constraint, lw=lw)
agent.valid_ax.set_xlim(*x_dof.limits)
agent.valid_ax.set_xlim(*x_dof.search_bounds)


def _plot_valid_many_dofs(agent, axes=[0, 1], shading="nearest", cmap=DEFAULT_COLORMAP, size=16, gridded=None):
Expand Down Expand Up @@ -327,8 +327,8 @@ def _plot_valid_many_dofs(agent, axes=[0, 1], shading="nearest", cmap=DEFAULT_CO
for ax in agent.acq_axes.ravel():
ax.set_xlabel(x_dof.label)
ax.set_ylabel(y_dof.label)
ax.set_xlim(*x_dof.limits)
ax.set_ylim(*y_dof.limits)
ax.set_xlim(*x_dof.search_bounds)
ax.set_ylim(*y_dof.search_bounds)


def _plot_history(agent, x_key="index", show_all_objs=False):
Expand Down
22 changes: 12 additions & 10 deletions blop/bayesian/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@
from torch import Tensor


def targeting_transform(y, target):
if target == "min":
y = -y
elif not isinstance(target, tuple):
y = -(y - target).abs()
else:
y = -((y - 0.5 * (target[1] + target[0])).abs() - 0.5 * (target[1] - target[0])).clamp(min=0)
return y


class TargetingPosteriorTransform(PosteriorTransform):
r"""An affine posterior transform for scalarizing multi-output posteriors."""

Expand All @@ -25,20 +35,12 @@ def __init__(self, weights: Tensor, targets: Tensor) -> None:

def sampled_transform(self, y):
for i, target in enumerate(self.targets):
if target == "min":
y[..., i] = -y[..., i]
elif target != "max":
y[..., i] = -(y[..., i] - target).abs()

y[..., i] = targeting_transform(y[..., i], target)
return y @ self.weights.unsqueeze(-1)

def mean_transform(self, mean, var):
for i, target in enumerate(self.targets):
if target == "min":
mean[..., i] = -mean[..., i]
elif target != "max":
mean[..., i] = -(mean[..., i] - target).abs()

mean[..., i] = targeting_transform(mean[..., i], target)
return mean @ self.weights.unsqueeze(-1)

def variance_transform(self, mean, var):
Expand Down
Loading

0 comments on commit 616b2df

Please sign in to comment.