Skip to content

Commit

Permalink
fixed notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaswmorris committed Jan 16, 2024
1 parent aa7ebee commit ca18732
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 38 deletions.
2 changes: 1 addition & 1 deletion blop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import utils # noqa F401
from ._version import get_versions
from .agent import Agent # noqa F401
from .agent import Agent # noqa F401
from .dofs import DOF # noqa F401
from .objectives import Objective # noqa F401

Expand Down
27 changes: 11 additions & 16 deletions blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from ophyd import Signal

from . import utils
from .dofs import DOF, DOFList
from .objectives import Objective, ObjectiveList
from .bayesian import acquisition, models, plotting
from .bayesian.transforms import TargetingPosteriorTransform
from .digestion import default_digestion_function
from .dofs import DOF, DOFList
from .objectives import Objective, ObjectiveList
from .plans import default_acquisition_plan
from .bayesian.transforms import TargetingPosteriorTransform

warnings.filterwarnings("ignore", category=botorch.exceptions.warnings.InputDataWarning)

Expand All @@ -37,19 +37,17 @@
MAX_TEST_INPUTS = 2**11



def _validate_dofs_and_objs(dofs: DOFList, objs: ObjectiveList):

if len(dofs) == 0:
raise ValueError(f"You must supply at least one DOF.")
raise ValueError("You must supply at least one DOF.")

if len(objs) == 0:
raise ValueError(f"You must supply at least one objective.")
raise ValueError("You must supply at least one objective.")

for obj in objs:
for latent_group in obj.latent_groups:
for dof_name in latent_group:
if not dof_name in dofs.names:
if dof_name not in dofs.names:
raise ValueError(f"DOF name '{dof_name}' in latent group for objective '{obj.name}' does not exist.")


Expand Down Expand Up @@ -144,13 +142,11 @@ def __iter__(self):
yield self.dofs[index]

def __getattr__(self, attr):

acq_func_name = acquisition.parse_acq_func_identifier(attr)
if acq_func_name is not None:
return self.get_acquisition_function(identifier=acq_func_name)

raise AttributeError(f"DOFList object has no attribute named '{attr}'.")

raise AttributeError(f"DOFList object has no attribute named '{attr}'.")

def view(self, item: str = "mean", cmap: str = "turbo", max_inputs: int = MAX_TEST_INPUTS):
"""
Expand Down Expand Up @@ -513,7 +509,7 @@ def _construct_model(self, obj, skew_dims=None):
return model

def _construct_classifier(self, skew_dims=None):
skew_dims = [tuple([i]) for i in range(len(self.dofs))]
skew_dims = [tuple([i]) for i in range(len(self.dofs.subset(active=True)))]

train_inputs = self.train_inputs(active=True)
trusted = ~torch.isnan(train_inputs).any(axis=1)
Expand Down Expand Up @@ -654,7 +650,6 @@ def acquisition_function_bounds(self):
acq_func_upper_bounds = np.where(self.dofs.read_only, self.dofs.readback, self.dofs.search_upper_bounds)

return torch.tensor(np.vstack([acq_func_lower_bounds, acq_func_upper_bounds]), dtype=torch.double)


def latent_dim_tuples(self, obj_index=None):
"""
Expand All @@ -663,8 +658,8 @@ def latent_dim_tuples(self, obj_index=None):
"""

if obj_index is None:
return {obj.name:self.latent_dim_tuples(obj_index=obj.name) for obj in self.objectives}
return {obj.name: self.latent_dim_tuples(obj_index=obj.name) for obj in self.objectives}

obj = self.objectives[obj_index]

latent_group_index = {}
Expand All @@ -673,7 +668,7 @@ def latent_dim_tuples(self, obj_index=None):
for group_index, latent_group in enumerate(obj.latent_groups):
if dof.name in latent_group:
latent_group_index[dof.name] = group_index

u, uinv = np.unique(list(latent_group_index.values()), return_inverse=True)
return [tuple(np.where(uinv == i)[0]) for i in range(len(u))]

Expand Down
3 changes: 2 additions & 1 deletion blop/bayesian/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def parse_acq_func_identifier(identifier):
return acq_func_name
return None


def get_acquisition_function(agent, identifier="qei", return_metadata=True, verbose=False, **acq_func_kwargs):
"""Generates an acquisition function from a supplied identifier. A list of acquisition functions and
their identifiers can be found at `agent.all_acq_funcs`.
Expand All @@ -31,7 +32,7 @@ def get_acquisition_function(agent, identifier="qei", return_metadata=True, verb
acq_func_name = parse_acq_func_identifier(identifier)
if acq_func_name is None:
raise ValueError(f'Unrecognized acquisition function identifier "{identifier}".')

acq_func_config = config["upper_confidence_bound"]

if config[acq_func_name]["multitask_only"] and (len(agent.objectives) == 1):
Expand Down
2 changes: 1 addition & 1 deletion blop/bayesian/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
)

if self.n_skew_entries > 0:
skew_entries_constraint = gpytorch.constraints.Interval(-2*np.pi, 2*np.pi)
skew_entries_constraint = gpytorch.constraints.Interval(-2 * np.pi, 2 * np.pi)
skew_entries_initial = torch.zeros((self.num_outputs, self.n_skew_entries), dtype=torch.float64)
self.register_parameter(name="raw_skew_entries", parameter=torch.nn.Parameter(skew_entries_initial))
self.register_constraint(param_name="raw_skew_entries", constraint=skew_entries_constraint)
Expand Down
1 change: 0 additions & 1 deletion blop/bayesian/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL
test_y = test_inputs[..., 0, axes[1]].detach().squeeze().numpy()

for obj_index, obj in enumerate(agent.objectives):

targets = agent.train_targets(obj.name).squeeze(-1).numpy()

obj_vmin, obj_vmax = np.nanpercentile(targets, q=[1, 99])
Expand Down
2 changes: 1 addition & 1 deletion blop/dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def search_lower_bound(self):
@property
def search_upper_bound(self):
return float(self.search_bounds[1])

@property
def trust_lower_bound(self):
return float(self.trust_bounds[0])
Expand Down
4 changes: 2 additions & 2 deletions blop/objectives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Tuple, List, Union
from typing import List, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -134,7 +134,7 @@ def __getattr__(self, attr):
if attr in self.names:
return self.__getitem__(attr)
else:
raise AttributeError(f'No attribute named {attr}')
raise AttributeError(f"No attribute named {attr}")

def __getitem__(self, i):
if type(i) is int:
Expand Down
11 changes: 5 additions & 6 deletions docs/source/tutorials/himmelblau.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@
"metadata": {},
"outputs": [],
"source": [
"from blop.bayesian import DOF\n",
"from blop import DOF\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", limits=(-6, 6)),\n",
" DOF(name=\"x2\", limits=(-6, 6)),\n",
" DOF(name=\"x1\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x2\", search_bounds=(-6, 6)),\n",
"]"
]
},
Expand All @@ -92,7 +92,7 @@
"metadata": {},
"outputs": [],
"source": [
"from blop.bayesian import Objective\n",
"from blop import Objective\n",
"\n",
"objectives = [Objective(name=\"himmelblau\", description=\"Himmeblau's function\", target=\"min\")]"
]
Expand Down Expand Up @@ -140,8 +140,7 @@
},
"outputs": [],
"source": [
"from blop.bayesian import Agent\n",
"\n",
"from blop import Agent\n",
"\n",
"agent = Agent(\n",
" dofs=dofs,\n",
Expand Down
18 changes: 14 additions & 4 deletions docs/source/tutorials/hyperparameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@
"\n",
"%run -i $prepare_re_env.__file__ --db-type=temp\n",
"\n",
"from blop.bayesian import DOF, Objective, Agent\n",
"from blop import DOF, Objective, Agent\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", limits=(-6, 6)),\n",
" DOF(name=\"x2\", limits=(-6, 6)),\n",
" DOF(name=\"x1\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x2\", search_bounds=(-6, 6)),\n",
"]\n",
"\n",
"objectives = [\n",
Expand Down Expand Up @@ -127,6 +127,16 @@
"RE(agent.learn(\"qei\", n=4, iterations=4))\n",
"agent.plot_objectives()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b70eaf9b",
"metadata": {},
"outputs": [],
"source": [
"agent.best"
]
}
],
"metadata": {
Expand All @@ -145,7 +155,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.10.0"
},
"vscode": {
"interpreter": {
Expand Down
21 changes: 16 additions & 5 deletions docs/source/tutorials/passive-dofs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@
"outputs": [],
"source": [
"from blop.utils import functions\n",
"from blop.bayesian import DOF, Agent, BrownianMotion, Objective\n",
"from blop import DOF, Agent, Objective\n",
"from blop.dofs import BrownianMotion\n",
"\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", limits=(-5.0, 5.0)),\n",
" DOF(name=\"x2\", limits=(-5.0, 5.0)),\n",
" DOF(name=\"x3\", limits=(-5.0, 5.0), active=False),\n",
" DOF(name=\"x1\", search_bounds=(-5.0, 5.0)),\n",
" DOF(name=\"x2\", search_bounds=(-5.0, 5.0)),\n",
" DOF(name=\"x3\", search_bounds=(-5.0, 5.0), active=False),\n",
" DOF(device=BrownianMotion(name=\"brownian1\"), read_only=True),\n",
" DOF(device=BrownianMotion(name=\"brownian2\"), read_only=True, active=False),\n",
"]\n",
Expand Down Expand Up @@ -75,6 +76,16 @@
"source": [
"agent.plot_objectives()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "312610b1",
"metadata": {},
"outputs": [],
"source": [
"agent.latent_dim_tuples()"
]
}
],
"metadata": {
Expand All @@ -93,7 +104,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.10.0"
},
"vscode": {
"interpreter": {
Expand Down

0 comments on commit ca18732

Please sign in to comment.