Skip to content

Commit

Permalink
wrap up als work
Browse files Browse the repository at this point in the history
  • Loading branch information
taxe10 committed Nov 20, 2023
1 parent 34323e7 commit 6364b6d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
12 changes: 10 additions & 2 deletions bloptools/bayesian/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def _construct_models(self, train=True, skew_dims=None, a_priori_hypers=None):
inputs = self.table.loc[:, self.dofs.subset(active=True).names].values.astype(float)

for i, obj in enumerate(self.objectives):

values = self.get_objective_targets(i)
values = np.where(self.all_objectives_valid, values, np.nan)

train_index = ~np.isnan(values)

Expand Down Expand Up @@ -218,7 +220,7 @@ def _construct_models(self, train=True, skew_dims=None, a_priori_hypers=None):

self.constraint = GenericDeterministicModel(f=lambda x: self.classifier.probabilities(x)[..., -1])

def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True):
def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True, upsample=1):
"""Ask the agent for the best point to sample, given an acquisition function.
Parameters
Expand Down Expand Up @@ -306,6 +308,11 @@ def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True):
routing_index = utils.route(self.dofs.subset(active=True, read_only=False).readback, acq_points)
acq_points = acq_points[routing_index]

if upsample > 1:
idx = np.arange(len(acq_points))
upsampled_idx = np.linspace(0, len(idx) - 1, upsample * len(idx) - 1)
acq_points = sp.interpolate.interp1d(idx, acq_points, axis=0)(upsampled_idx)

return acq_points, acq_func_meta

def acquire(self, acquisition_inputs):
Expand Down Expand Up @@ -409,7 +416,7 @@ def learn(
for i in range(iterations):
print(f"running iteration {i + 1} / {iterations}")
for single_acq_func in np.atleast_1d(acq_func):
acq_points, acq_func_meta = self.ask(n=n, acq_func_identifier=single_acq_func)
acq_points, acq_func_meta = self.ask(n=n, acq_func_identifier=single_acq_func, upsample=upsample)
new_table = yield from self.acquire(acq_points)
new_table.loc[:, "acq_func"] = acq_func_meta["name"]

Expand Down Expand Up @@ -746,6 +753,7 @@ def go_to(self, **positions):
def go_to_best(self):
"""Go to the position of the best input seen so far."""
yield from self.go_to(**self.best_inputs)


def plot_objectives(self, axes: Tuple = (0, 1), **kwargs):
"""Plot the sampled objectives
Expand Down
5 changes: 4 additions & 1 deletion bloptools/bayesian/dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ def __init__(self, dofs: list = []):
self.dofs = dofs

def __getitem__(self, i):
return self.dofs[i]
if type(i) is int:
return self.dofs[i]
elif type(i) is str:
return self.dofs[self.names.index(i)]

def __len__(self):
return len(self.dofs)
Expand Down

0 comments on commit 6364b6d

Please sign in to comment.