Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add napari viewer #55

Merged
merged 4 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 150 additions & 160 deletions blop/bayesian/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import gpytorch
import h5py
import matplotlib as mpl
import napari
import numpy as np
import pandas as pd
import scipy as sp
Expand Down Expand Up @@ -119,6 +120,45 @@ def __init__(

self.n_last_trained = 0

def view(self, item: str = "mean", cmap: str = "turbo", max_inputs: int = MAX_TEST_INPUTS):
"""
Use napari to see a high-dimensional array.

Parameters
----------
item : str
The thing to be viewed. Either 'mean', 'error', or an acquisition function.
"""

test_grid = self.test_inputs_grid(max_inputs=max_inputs)

self.viewer = napari.Viewer()

if item in ["mean", "error"]:
for obj in self.objectives:
p = obj.model.posterior(test_grid)

if item == "mean":
mean = p.mean.detach().numpy()[..., 0, 0]
self.viewer.add_image(data=mean, name=f"{obj.name}_mean", colormap=cmap)

if item == "error":
error = np.sqrt(p.variance.detach().numpy()[..., 0, 0])
self.viewer.add_image(data=error, name=f"{obj.name}_error", colormap=cmap)

else:
try:
acq_func_identifier = acquisition.parse_acq_func_identifier(identifier=item)
except Exception:
raise ValueError("'item' must be either 'mean', 'error', or a valid acq func.")

acq_func, acq_func_meta = self.get_acquisition_function(identifier=acq_func_identifier, return_metadata=True)
a = acq_func(test_grid).detach().numpy()

self.viewer.add_image(data=a, name=f"{acq_func_identifier}", colormap=cmap)

self.viewer.dims.axis_labels = self.dofs.names

def tell(
self,
data: Optional[Mapping] = {},
Expand All @@ -142,7 +182,7 @@ def tell(
A dict keyed by the name of each objective, with a list of values for each objective.
append: bool
If `True`, will append new data to old data. If `False`, will replace old data with new data.
train_models: bool
train: bool
Whether to train the models on construction.
hypers:
A dict of hyperparameters for the model to assume a priori, instead of training.
Expand All @@ -168,147 +208,16 @@ def tell(

cached_hypers = obj.model.state_dict() if hasattr(obj, "model") else None

obj.model = self.construct_model(obj)
obj.model = self._construct_model(obj)

if len(obj.model.train_targets) >= 2:
t0 = ttime.monotonic()
self.train_model(obj.model, hypers=(None if train else cached_hypers))
self._train_model(obj.model, hypers=(None if train else cached_hypers))
if self.verbose:
print(f"trained model '{obj.name}' in {1e3*(ttime.monotonic() - t0):.00f} ms")

# TODO: should this be per objective?
self.construct_classifier()

def train_model(self, model, hypers=None, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
if hypers is not None:
model.load_state_dict(hypers)
else:
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
model.trained = True

def construct_model(self, obj, skew_dims=None):
"""
Construct an untrained model for an objective.
"""

skew_dims = skew_dims if skew_dims is not None else self.latent_dim_tuples

likelihood = gpytorch.likelihoods.GaussianLikelihood(
noise_constraint=gpytorch.constraints.Interval(
torch.tensor(1e-4).square(),
torch.tensor(1 / obj.min_snr).square(),
),
# noise_prior=gpytorch.priors.torch_priors.LogNormalPrior(loc=loc, scale=scale),
)

outcome_transform = botorch.models.transforms.outcome.Standardize(m=1) # , batch_shape=torch.Size((1,)))

train_inputs = self.train_inputs(active=True)
train_targets = self.train_targets(obj.name)

safe = ~(torch.isnan(train_inputs).any(axis=1) | torch.isnan(train_targets).any(axis=1))

model = models.LatentGP(
train_inputs=train_inputs[safe],
train_targets=train_targets[safe],
likelihood=likelihood,
skew_dims=skew_dims,
input_transform=self.input_transform,
outcome_transform=outcome_transform,
)

model.trained = False

return model

def construct_classifier(self, skew_dims=None):
skew_dims = skew_dims if skew_dims is not None else self.latent_dim_tuples

dirichlet_likelihood = gpytorch.likelihoods.DirichletClassificationLikelihood(
self.all_objectives_valid.long(), learn_additional_noise=True
)

self.classifier = models.LatentDirichletClassifier(
train_inputs=self.train_inputs(active=True),
train_targets=dirichlet_likelihood.transformed_targets.transpose(-1, -2).double(),
skew_dims=skew_dims,
likelihood=dirichlet_likelihood,
input_transform=self.input_transform,
)

self.train_model(self.classifier)
self.constraint = GenericDeterministicModel(f=lambda x: self.classifier.probabilities(x)[..., -1])

# def construct_model(self, obj, skew_dims=None, a_priori_hypers=None):
# '''
# Construct an untrained model for an objective.
# '''
# skew_dims = skew_dims if skew_dims is not None else self.latent_dim_tuples

# inputs = self.table.loc[:, self.dofs.subset(active=True).names].values.astype(float)

# for i, obj in enumerate(self.objectives):
# values = self.train_targets(i)
# values = np.where(self.all_objectives_valid, values, np.nan)

# train_index = ~np.isnan(values)

# if not train_index.sum() >= 2:
# raise ValueError("There must be at least two valid data points per objective!")

# train_inputs = torch.tensor(inputs[train_index], dtype=torch.double)
# train_values = torch.tensor(values[train_index], dtype=torch.double).unsqueeze(-1) # .unsqueeze(0)

# # for constructing the log normal noise prior
# # target_snr = 2e2
# # scale = 2e0
# # loc = np.log(1 / target_snr**2) + scale**2

# likelihood = gpytorch.likelihoods.GaussianLikelihood(
# noise_constraint=gpytorch.constraints.Interval(
# torch.tensor(1e-4).square(),
# torch.tensor(1 / obj.min_snr).square(),
# ),
# # noise_prior=gpytorch.priors.torch_priors.LogNormalPrior(loc=loc, scale=scale),
# )

# outcome_transform = botorch.models.transforms.outcome.Standardize(m=1) # , batch_shape=torch.Size((1,)))

# obj.model = models.LatentGP(
# train_inputs=train_inputs,
# train_targets=self.t,
# likelihood=likelihood,
# skew_dims=skew_dims,
# input_transform=self.input_transform,
# outcome_transform=outcome_transform,
# )

# dirichlet_likelihood = gpytorch.likelihoods.DirichletClassificationLikelihood(
# self.all_objectives_valid.long(), learn_additional_noise=True
# )

# self.classifier = models.LatentDirichletClassifier(
# train_inputs=torch.tensor(inputs).double(),
# train_targets=dirichlet_likelihood.transformed_targets.transpose(-1, -2).double(),
# skew_dims=skew_dims,
# likelihood=dirichlet_likelihood,
# input_transform=self._subset_input_transform(active=True),
# )

# if a_priori_hypers is not None:
# self._set_hypers(a_priori_hypers)
# else:
# self._train_models()
# # try:

# # except botorch.exceptions.errors.ModelFittingError:
# # if self.initialized:
# # self._set_hypers(cached_hypers)
# # else:
# # raise RuntimeError('Could not fit model on initialization!')

# self.constraint = GenericDeterministicModel(f=lambda x: self.classifier.probabilities(x)[..., -1])
self._construct_classifier()

def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True, upsample=1, **acq_func_kwargs):
"""Ask the agent for the best point to sample, given an acquisition function.
Expand Down Expand Up @@ -527,6 +436,67 @@ def learn(
metadata = new_table.to_dict(orient="list")
self.tell(x=x, y=y, metadata=metadata, append=append, train=train)

def _train_model(self, model, hypers=None, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
if hypers is not None:
model.load_state_dict(hypers)
else:
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
model.trained = True

def _construct_model(self, obj, skew_dims=None):
"""
Construct an untrained model for an objective.
"""

skew_dims = skew_dims if skew_dims is not None else self.latent_dim_tuples

likelihood = gpytorch.likelihoods.GaussianLikelihood(
noise_constraint=gpytorch.constraints.Interval(
torch.tensor(1e-4).square(),
torch.tensor(1 / obj.min_snr).square(),
),
# noise_prior=gpytorch.priors.torch_priors.LogNormalPrior(loc=loc, scale=scale),
)

outcome_transform = botorch.models.transforms.outcome.Standardize(m=1) # , batch_shape=torch.Size((1,)))

train_inputs = self.train_inputs(active=True)
train_targets = self.train_targets(obj.name)

safe = ~(torch.isnan(train_inputs).any(axis=1) | torch.isnan(train_targets).any(axis=1))

model = models.LatentGP(
train_inputs=train_inputs[safe],
train_targets=train_targets[safe],
likelihood=likelihood,
skew_dims=skew_dims,
input_transform=self.input_transform,
outcome_transform=outcome_transform,
)

model.trained = False

return model

def _construct_classifier(self, skew_dims=None):
skew_dims = skew_dims if skew_dims is not None else self.latent_dim_tuples

dirichlet_likelihood = gpytorch.likelihoods.DirichletClassificationLikelihood(
self.all_objectives_valid.long(), learn_additional_noise=True
)

self.classifier = models.LatentDirichletClassifier(
train_inputs=self.train_inputs(active=True),
train_targets=dirichlet_likelihood.transformed_targets.transpose(-1, -2).double(),
skew_dims=skew_dims,
likelihood=dirichlet_likelihood,
input_transform=self.input_transform,
)

self._train_model(self.classifier)
self.constraint = GenericDeterministicModel(f=lambda x: self.classifier.probabilities(x)[..., -1])

def get_acquisition_function(self, identifier, return_metadata=False):
"""Returns a BoTorch acquisition function for a given identifier. Acquisition functions can be
found in `agent.all_acq_funcs`.
Expand Down Expand Up @@ -630,19 +600,23 @@ def test_inputs_grid(self, max_inputs=MAX_TEST_INPUTS):
n_settable_acq_func_dofs = len(self.dofs.subset(active=True, read_only=False))
n_side_settable = int(np.power(max_inputs, n_settable_acq_func_dofs**-1))
n_sides = [1 if dof.read_only else n_side_settable for dof in self.dofs.subset(active=True)]
return torch.cat(
[
tensor.unsqueeze(-1)
for tensor in torch.meshgrid(
*[
torch.linspace(lower_limit, upper_limit, n_side)
for (lower_limit, upper_limit), n_side in zip(self.dofs.subset(active=True).limits, n_sides)
],
indexing="ij",
)
],
dim=-1,
).unsqueeze(-2)
return (
torch.cat(
[
tensor.unsqueeze(-1)
for tensor in torch.meshgrid(
*[
torch.linspace(lower_limit, upper_limit, n_side)
for (lower_limit, upper_limit), n_side in zip(self.dofs.subset(active=True).limits, n_sides)
],
indexing="ij",
)
],
dim=-1,
)
.unsqueeze(-2)
.double()
)

def test_inputs(self, n=MAX_TEST_INPUTS):
"""Returns a (n, 1, n_active_dof) grid of test_inputs"""
Expand Down Expand Up @@ -700,20 +674,31 @@ def save_data(self, filepath="./self_data.h5"):

self.table.to_hdf(filepath, key="table")

def forget(self, index, train=True):
"""
Make the agent forget some index of the data table.
def forget(self, last=None, index=None, train=True):
"""
self.table.drop(index=index, inplace=True)
self._construct_models(train=train)
Make the agent forget some data.

def forget_last_n(self, n, train=True):
"""
Make the agent forget the last `n` data points taken.
Parameters
----------
index :
An index of samples to forget about.
last : int
Forget the last n=last points.
"""
if n > len(self.table):
raise ValueError(f"Cannot forget {n} data points (only {len(self.table)} have been taken).")
self.forget(self.table.index.iloc[-n:], train=train)

if last is not None:
if last > len(self.table):
raise ValueError(f"Cannot forget last {last} data points (only {len(self.table)} samples have been taken).")
self.forget(index=self.table.index.values[-last:], train=train)

elif index is not None:
self.table.drop(index=index, inplace=True)
self._construct_all_models()
if train:
self._train_all_models()

else:
raise ValueError("Must supply either 'last' or 'index'.")

def sampler(self, n, d):
"""
Expand Down Expand Up @@ -761,7 +746,12 @@ def load_hypers(filepath):
hypers[model_key][param_key] = torch.tensor(np.atleast_1d(param_value[()]))
return hypers

def _train_models(self, **kwargs):
def _construct_all_models(self):
"""Construct a model for each objective."""
for obj in self.objectives:
obj.model = self._construct_model(obj)

def _train_all_models(self, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
t0 = ttime.monotonic()
for obj in self.objectives:
Expand Down
Loading