Skip to content

Commit

Permalink
rebased on main
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Apr 23, 2024
1 parent d28c620 commit e0f312c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 51 deletions.
74 changes: 37 additions & 37 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def sample(self, n: int = DEFAULT_MAX_SAMPLES, method: str = "quasi-random") ->
How to sample the points. Must be one of 'quasi-random', 'random', or 'grid'.
"""

active_dofs = self.dofs.subset(active=True)
active_dofs = self.dofs(active=True)

if method == "quasi-random":
X = utils.normalized_sobol_sampler(n, d=len(active_dofs))
Expand All @@ -204,7 +204,7 @@ def sample(self, n: int = DEFAULT_MAX_SAMPLES, method: str = "quasi-random") ->
else:
raise ValueError("'method' argument must be one of ['quasi-random', 'random', 'grid'].")

return self.dofs.subset(active=True).untransform(X)
return self.dofs(active=True).untransform(X)

def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_kwargs):
"""Ask the agent for the best point to sample, given an acquisition function.
Expand All @@ -228,8 +228,8 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k

start_time = ttime.monotonic()

active_dofs = self.dofs.subset(active=True)
active_objs = self.objectives.subset(active=True)
active_dofs = self.dofs(active=True)
active_objs = self.objectives(active=True)

# these are the fake acquisiton functions that we don't need to construct
if acqf_config["name"] in ["quasi-random", "random", "grid"]:
Expand Down Expand Up @@ -271,11 +271,11 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k

# this includes both RO and non-RO DOFs.
# and is in the transformed model space
candidates = self.dofs.subset(active=True).untransform(candidates).numpy()
candidates = self.dofs(active=True).untransform(candidates).numpy()

p = self.posterior(candidates) if hasattr(self, "model") else None

active_dofs = self.dofs.subset(active=True)
active_dofs = self.dofs(active=True)

points = candidates[..., ~active_dofs.read_only]
read_only_values = candidates[..., active_dofs.read_only]
Expand Down Expand Up @@ -354,7 +354,7 @@ def tell(
self.table.index = np.arange(len(self.table))

if update_models:
for obj in self.objectives.subset(active=True):
for obj in self.objectives(active=True):
t0 = ttime.monotonic()

cached_hypers = obj.model.state_dict() if hasattr(obj, "model") else None
Expand Down Expand Up @@ -414,7 +414,7 @@ def learn(
"""

if self.sample_center_on_init and not self.initialized:
center_inputs = np.atleast_2d(self.dofs.subset(active=True, read_only=False).search_domain.mean(axis=1))
center_inputs = np.atleast_2d(self.dofs(active=True, read_only=False).search_domain.mean(axis=1))
new_table = yield from self.acquire(center_inputs)
new_table.loc[:, "acqf"] = "sample_center_on_init"

Expand Down Expand Up @@ -448,7 +448,7 @@ def view(self, item: str = "mean", cmap: str = "turbo", max_inputs: int = 2**16)
self.viewer = napari.Viewer()

if item in ["mean", "error"]:
for obj in self.objectives.subset(active=True):
for obj in self.objectives(active=True):
p = obj.model.posterior(test_grid)

if item == "mean":
Expand Down Expand Up @@ -485,7 +485,7 @@ def acquire(self, acquisition_inputs):
raise ValueError("Cannot run acquistion without databroker instance!")

try:
acquisition_devices = self.dofs.subset(active=True, read_only=False).devices
acquisition_devices = self.dofs(active=True, read_only=False).devices
uid = yield from self.acquisition_plan(
acquisition_devices,
acquisition_inputs.astype(float),
Expand All @@ -501,8 +501,8 @@ def acquire(self, acquisition_inputs):
if not self.tolerate_acquisition_errors:
raise error
logging.warning(f"Error in acquisition/digestion: {repr(error)}")
products = pd.DataFrame(acquisition_inputs, columns=self.dofs.subset(active=True, read_only=False).names)
for obj in self.objectives.subset(active=True):
products = pd.DataFrame(acquisition_inputs, columns=self.dofs(active=True, read_only=False).names)
for obj in self.objectives(active=True):
products.loc[:, obj.name] = np.nan

if not len(acquisition_inputs) == len(products):
Expand All @@ -521,7 +521,7 @@ def reset(self):
"""Reset the agent."""
self.table = pd.DataFrame()

for obj in self.objectives.subset(active=True):
for obj in self.objectives(active=True):
if hasattr(obj, "model"):
del obj.model

Expand Down Expand Up @@ -556,16 +556,16 @@ def benchmark(
@property
def model(self):
"""A model encompassing all the fitnesses and constraints."""
active_objs = self.objectives.subset(active=True)
active_objs = self.objectives(active=True)
return ModelListGP(*[obj.model for obj in active_objs]) if len(active_objs) > 1 else active_objs[0].model

def posterior(self, x):
"""A model encompassing all the objectives. A single GP in the single-objective case, or a model list."""
return self.model.posterior(self.dofs.transform(torch.tensor(x)))
return self.model.posterior(self.dofs(active=True).transform(torch.tensor(x)))

@property
def fitness_model(self):
active_fitness_models = self.objectives.subset(active=True, kind="fitness")
active_fitness_models = self.objectives(active=True, kind="fitness")
if len(active_fitness_models) == 0:
raise ValueError("Having no fitness objectives is unhandled.")
if len(active_fitness_models) == 1:
Expand All @@ -574,14 +574,14 @@ def fitness_model(self):

@property
def evaluated_constraints(self):
constraint_objectives = self.objectives.subset(kind="constraint")
constraint_objectives = self.objectives(kind="constraint")
if len(constraint_objectives):
return torch.cat([obj.constrain(self.raw_targets(obj.name)) for obj in constraint_objectives], dim=-1)
else:
return torch.ones(size=(len(self.table), 0), dtype=torch.bool)

def fitness_scalarization(self, weights="default"):
fitness_objectives = self.objectives.subset(active=True, kind="fitness")
fitness_objectives = self.objectives(active=True, kind="fitness")
if weights == "default":
weights = torch.tensor([obj.weight for obj in fitness_objectives], dtype=torch.double)
elif weights == "equal":
Expand Down Expand Up @@ -682,7 +682,7 @@ def _construct_model(self, obj, skew_dims=None):
outcome_transform=outcome_transform,
)

obj.model_dofs = set(self.dofs.subset(active=True).names) # if these change, retrain the model on self.ask()
obj.model_dofs = set(self.dofs(active=True).names) # if these change, retrain the model on self.ask()

if trusted.all():
obj.validity_conjugate_model = None
Expand All @@ -707,13 +707,13 @@ def _construct_model(self, obj, skew_dims=None):

def _construct_all_models(self):
"""Construct a model for each objective."""
for obj in self.objectives.subset(active=True):
for obj in self.objectives(active=True):
self._construct_model(obj)

def _train_all_models(self, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
t0 = ttime.monotonic()
for obj in self.objectives.subset(active=True):
for obj in self.objectives(active=True):
self._train_model(obj.model)
if obj.validity_conjugate_model is not None:
self._train_model(obj.validity_conjugate_model)
Expand Down Expand Up @@ -744,7 +744,7 @@ def _latent_dim_tuples(self, obj_index=None):
obj = self.objectives[obj_index]

latent_group_index = {}
for dof in self.dofs.subset(active=True):
for dof in self.dofs(active=True):
latent_group_index[dof.name] = dof.name
for group_index, latent_group in enumerate(obj.latent_groups):
if dof.name in latent_group:
Expand All @@ -754,13 +754,13 @@ def _latent_dim_tuples(self, obj_index=None):
return [tuple(np.where(uinv == i)[0]) for i in range(len(u))]

@property
def _sample_domain(self):
def sample_domain(self):
"""
Returns a (2, n_active_dof) array of lower and upper bounds for dofs.
Read-only DOFs are set to exactly their last known value.
Discrete DOFs are relaxed to some continuous domain.
"""
return self.dofs.subset(active=True).transform(self.dofs.subset(active=True).search_domain.T)
return self.dofs(active=True).transform(self.dofs(active=True).search_domain.T)

@property
def input_normalization(self):
Expand Down Expand Up @@ -816,15 +816,15 @@ def forget(self, last=None, index=None, train=True):
raise ValueError("Must supply either 'last' or 'index'.")

def _set_hypers(self, hypers):
for obj in self.objectives.subset(active=True):
for obj in self.objectives(active=True):
obj.model.load_state_dict(hypers[obj.name])
self.validity_constraint.load_state_dict(hypers["validity_constraint"])

def constraint(self, x):
x = self.dofs.subset(active=True).transform(x)
x = self.dofs(active=True).transform(x)

p = torch.ones(x.shape[:-1])
for obj in self.objectives.subset(active=True):
for obj in self.objectives(active=True):
# if the targeting constraint is non-trivial
# if obj.kind == "constraint":
# p *= obj.targeting_constraint(x)
Expand Down Expand Up @@ -892,14 +892,14 @@ def raw_inputs(self, index=None, **subset_kwargs):
Get the raw, untransformed inputs for a DOF (or for a subset).
"""
if index is None:
return torch.cat([self.raw_inputs(dof.name) for dof in self.dofs.subset(**subset_kwargs)], dim=-1)
return torch.cat([self.raw_inputs(dof.name) for dof in self.dofs(**subset_kwargs)], dim=-1)
return torch.tensor(self.table.loc[:, self.dofs[index].name].values, dtype=torch.double).unsqueeze(-1)

def train_inputs(self, index=None, **subset_kwargs):
"""A two-dimensional tensor of all DOF values."""

if index is None:
return torch.cat([self.train_inputs(index=dof.name) for dof in self.dofs.subset(**subset_kwargs)], dim=-1)
return torch.cat([self.train_inputs(index=dof.name) for dof in self.dofs(**subset_kwargs)], dim=-1)

dof = self.dofs[index]
raw_inputs = self.raw_inputs(index=index, **subset_kwargs)
Expand All @@ -915,14 +915,14 @@ def raw_targets(self, index=None, **subset_kwargs):
Get the raw, untransformed inputs for an objective (or for a subset).
"""
if index is None:
return torch.cat([self.raw_targets(index=obj.name) for obj in self.objectives.subset(**subset_kwargs)], dim=-1)
return torch.cat([self.raw_targets(index=obj.name) for obj in self.objectives(**subset_kwargs)], dim=-1)
return torch.tensor(self.table.loc[:, self.objectives[index].name].values, dtype=torch.double).unsqueeze(-1)

def train_targets(self, index=None, **subset_kwargs):
"""Returns the values associated with an objective name."""

if index is None:
return torch.cat([self.train_targets(obj.name) for obj in self.objectives.subset(**subset_kwargs)], dim=-1)
return torch.cat([self.train_targets(obj.name) for obj in self.objectives(**subset_kwargs)], dim=-1)

obj = self.objectives[index]
raw_targets = self.raw_targets(index=index, **subset_kwargs)
Expand Down Expand Up @@ -973,10 +973,10 @@ def plot_objectives(self, axes: Tuple = (0, 1), **kwargs):
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
"""

if len(self.dofs.subset(active=True, read_only=False)) == 1:
if len(self.objectives.subset(active=True, kind="fitness")) > 0:
if len(self.dofs(active=True, read_only=False)) == 1:
if len(self.objectives(active=True, kind="fitness")) > 0:
plotting._plot_fitness_objs_one_dof(self, **kwargs)
if len(self.objectives.subset(active=True, kind="constraint")) > 0:
if len(self.objectives(active=True, kind="constraint")) > 0:
plotting._plot_constraint_objs_one_dof(self, **kwargs)
else:
plotting._plot_objs_many_dofs(self, axes=axes, **kwargs)
Expand All @@ -991,7 +991,7 @@ def plot_acquisition(self, acqf="ei", axes: Tuple = (0, 1), **kwargs):
axes :
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
"""
if len(self.dofs.subset(active=True, read_only=False)) == 1:
if len(self.dofs(active=True, read_only=False)) == 1:
plotting._plot_acqf_one_dof(self, acqfs=np.atleast_1d(acqf), **kwargs)

else:
Expand All @@ -1005,7 +1005,7 @@ def plot_validity(self, axes: Tuple = (0, 1), **kwargs):
axes :
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
"""
if len(self.dofs.subset(active=True, read_only=False)) == 1:
if len(self.dofs(active=True, read_only=False)) == 1:
plotting._plot_valid_one_dof(self, **kwargs)

else:
Expand All @@ -1017,7 +1017,7 @@ def plot_history(self, **kwargs):

@property
def latent_transforms(self):
return {obj.name: obj.model.covar_module.latent_transform for obj in self.objectives.subset(active=True)}
return {obj.name: obj.model.covar_module.latent_transform for obj in self.objectives(active=True)}

def plot_pareto_front(self, **kwargs):
"""Plot the improvement of the agent over time."""
Expand Down
2 changes: 1 addition & 1 deletion src/blop/bayesian/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _construct_acqf(agent, acqf_name, **acqf_kwargs):
constraint=agent.constraint,
model=agent.fitness_model,
# X_baseline=agent.input_normalization.forward(agent.train_inputs())[],
X_baseline=agent.dofs.transform(agent.train_inputs(active=True)),
X_baseline=agent.dofs(active=True).transform(agent.train_inputs(active=True)),
prune_baseline=True,
**acqf_kwargs,
)
Expand Down
26 changes: 13 additions & 13 deletions src/blop/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def _plot_fitness_objs_one_dof(agent, size=16, lw=1e0):
fitness_objs = agent.objectives.subset(kind="fitness")
fitness_objs = agent.objectives(kind="fitness")

agent.obj_fig, agent.obj_axes = plt.subplots(
len(fitness_objs),
Expand All @@ -26,11 +26,11 @@ def _plot_fitness_objs_one_dof(agent, size=16, lw=1e0):

agent.obj_axes = np.atleast_1d(agent.obj_axes)

x_dof = agent.dofs.subset(active=True)[0]
x_dof = agent.dofs(active=True)[0]
x_values = agent.table.loc[:, x_dof.device.name].values

test_inputs = agent.sample(method="grid")
test_model_inputs = agent.dofs.transform(test_inputs)
test_model_inputs = agent.dofs(active=True).transform(test_inputs)

for obj_index, obj in enumerate(fitness_objs):
obj_values = agent.train_targets(obj.name).squeeze(-1).numpy()
Expand Down Expand Up @@ -62,7 +62,7 @@ def _plot_fitness_objs_one_dof(agent, size=16, lw=1e0):


def _plot_constraint_objs_one_dof(agent, size=16, lw=1e0):
constraint_objs = agent.objectives.subset(kind="constraint")
constraint_objs = agent.objectives(kind="constraint")

agent.obj_fig, agent.obj_axes = plt.subplots(
len(constraint_objs),
Expand All @@ -74,11 +74,11 @@ def _plot_constraint_objs_one_dof(agent, size=16, lw=1e0):

agent.obj_axes = np.atleast_2d(agent.obj_axes)

x_dof = agent.dofs.subset(active=True)[0]
x_dof = agent.dofs(active=True)[0]
x_values = agent.table.loc[:, x_dof.device.name].values

test_inputs = agent.sample(method="grid")
test_model_inputs = agent.dofs.transform(test_inputs)
test_model_inputs = agent.dofs(active=True).transform(test_inputs)

for obj_index, obj in enumerate(constraint_objs):
val_ax = agent.obj_axes[obj_index, 0]
Expand Down Expand Up @@ -129,7 +129,7 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
Axes represents which active, non-read-only axes to plot with
"""

plottable_dofs = agent.dofs.subset(active=True, read_only=False)
plottable_dofs = agent.dofs(active=True, read_only=False)

if gridded is None:
gridded = len(plottable_dofs) == 2
Expand All @@ -155,7 +155,7 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
test_x = test_inputs[..., 0, axes[0]].detach().squeeze().numpy()
test_y = test_inputs[..., 0, axes[1]].detach().squeeze().numpy()

model_inputs = agent.dofs.subset(active=True).transform(test_inputs)
model_inputs = agent.dofs(active=True).transform(test_inputs)

for obj_index, obj in enumerate(agent.objectives):
targets = agent.train_targets(obj.name).squeeze(-1).numpy()
Expand Down Expand Up @@ -324,7 +324,7 @@ def _plot_acqf_one_dof(agent, acqfs, lw=1e0, **kwargs):
)

agent.acq_axes = np.atleast_1d(agent.acq_axes)
x_dof = agent.dofs.subset(active=True)[0]
x_dof = agent.dofs(active=True)[0]

test_inputs = agent.sample(method="grid")

Expand Down Expand Up @@ -355,7 +355,7 @@ def _plot_acqf_many_dofs(
constrained_layout=True,
)

plottable_dofs = agent.dofs.subset(active=True, read_only=False)
plottable_dofs = agent.dofs(active=True, read_only=False)

if gridded is None:
gridded = len(plottable_dofs) == 2
Expand Down Expand Up @@ -412,7 +412,7 @@ def _plot_acqf_many_dofs(
def _plot_valid_one_dof(agent, size=16, lw=1e0):
agent.valid_fig, agent.valid_ax = plt.subplots(1, 1, figsize=(6, 4 * len(agent.objectives)), constrained_layout=True)

x_dof = agent.dofs.subset(active=True)[0]
x_dof = agent.dofs(active=True)[0]
x_values = agent.table.loc[:, x_dof.device.name].values

test_inputs = agent.sample(method="grid")
Expand All @@ -426,7 +426,7 @@ def _plot_valid_one_dof(agent, size=16, lw=1e0):
def _plot_valid_many_dofs(agent, axes=[0, 1], shading="nearest", cmap=DEFAULT_COLORMAP, size=16, gridded=None):
agent.valid_fig, agent.valid_axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True)

plottable_dofs = agent.dofs.subset(active=True, read_only=False)
plottable_dofs = agent.dofs(active=True, read_only=False)

if gridded is None:
gridded = len(plottable_dofs) == 2
Expand Down Expand Up @@ -537,7 +537,7 @@ def inspect_beam(agent, index, border=None):


def _plot_pareto_front(agent, obj_indices=(0, 1)):
f_objs = agent.objectives.subset(kind="fitness")
f_objs = agent.objectives(kind="fitness")
(i, j) = obj_indices

if len(f_objs) < 2:
Expand Down

0 comments on commit e0f312c

Please sign in to comment.