Skip to content

Commit

Permalink
better test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Jan 8, 2024
1 parent 1f45119 commit 24a1886
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 23 deletions.
42 changes: 29 additions & 13 deletions bloptools/bayesian/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def tell(
A dict keyed by the name of each objective, with a list of values for each objective.
append: bool
If `True`, will append new data to old data. If `False`, will replace old data with new data.
_train_models: bool
train: bool
Whether to train the models on construction.
hypers:
A dict of hyperparameters for the model to assume a priori, instead of training.
Expand Down Expand Up @@ -674,20 +674,31 @@ def save_data(self, filepath="./self_data.h5"):

self.table.to_hdf(filepath, key="table")

def forget(self, index, train=True):
"""
Make the agent forget some index of the data table.
def forget(self, last=None, index=None, train=True):
"""
self.table.drop(index=index, inplace=True)
self.__construct_models(train=train)
Make the agent forget some data.
def forget_last_n(self, n, train=True):
"""
Make the agent forget the last `n` data points taken.
Parameters
----------
index :
An index of samples to forget about.
last : int
Forget the last n=last points.
"""
if n > len(self.table):
raise ValueError(f"Cannot forget {n} data points (only {len(self.table)} have been taken).")
self.forget(self.table.index.iloc[-n:], train=train)

if last is not None:
if last > len(self.table):
raise ValueError(f"Cannot forget last {last} data points (only {len(self.table)} samples have been taken).")
self.forget(index=self.table.index.values[-last:], train=train)

elif index is not None:
self.table.drop(index=index, inplace=True)
self._construct_all_models()
if train:
self._train_all_models()

else:
raise ValueError("Must supply either 'last' or 'index'.")

def sampler(self, n, d):
"""
Expand Down Expand Up @@ -735,7 +746,12 @@ def load_hypers(filepath):
hypers[model_key][param_key] = torch.tensor(np.atleast_1d(param_value[()]))
return hypers

def __train_models(self, **kwargs):
def _construct_all_models(self):
"""Construct a model for each objective."""
for obj in self.objectives:
obj.model = self._construct_model(obj)

def _train_all_models(self, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
t0 = ttime.monotonic()
for obj in self.objectives:
Expand Down
4 changes: 2 additions & 2 deletions bloptools/tests/test_acq_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

@pytest.mark.parametrize("acq_func", ["ei", "pi", "em", "ucb"])
def test_analytic_acq_funcs_single_objective(agent, RE, acq_func):
RE(agent.learn("qr", n=16))
RE(agent.learn("qr", n=4))
RE(agent.learn(acq_func, n=1))


@pytest.mark.parametrize("acq_func", ["qei", "qpi", "qem", "qucb"])
def test_monte_carlo_acq_funcs_single_objective(agent, RE, acq_func):
RE(agent.learn("qr", n=16))
RE(agent.learn("qr", n=4))
RE(agent.learn(acq_func, n=4))


Expand Down
10 changes: 10 additions & 0 deletions bloptools/tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest # noqa F401


def test_agent(agent, RE):
RE(agent.learn("qr", n=4))


def test_forget(agent, RE):
RE(agent.learn("qr", n=4))
agent.forget(last=2)
2 changes: 1 addition & 1 deletion bloptools/tests/test_napari.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@

@pytest.mark.parametrize("item", ["mean", "error", "qei"])
def test_napari_viewer(agent, RE, item):
RE(agent.learn("qr", n=16))
RE(agent.learn("qr", n=4))
agent.view(item)
10 changes: 3 additions & 7 deletions bloptools/tests/test_read_write.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import pytest # noqa F401


def test_agent(agent, RE):
RE(agent.learn("qr", n=16))


def test_agent_save_load_data(agent, RE):
RE(agent.learn("qr", n=16))
RE(agent.learn("qr", n=4))
agent.save_data("/tmp/test_save_data.h5")
agent.reset()
agent.load_data(data_file="/tmp/test_save_data.h5")
RE(agent.learn("qr", n=16))
RE(agent.learn("qr", n=4))


def test_agent_save_load_hypers(agent, RE):
RE(agent.learn("qr", n=16))
RE(agent.learn("qr", n=4))
agent.save_hypers("/tmp/test_save_hypers.h5")
agent.reset()
RE(agent.learn("qr", n=16, hypers_file="/tmp/test_save_hypers.h5"))

0 comments on commit 24a1886

Please sign in to comment.