From 86cfc592196fe71e5a1c0e8f0b9eabda3abf83e8 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 14 Apr 2022 08:51:35 -0700 Subject: [PATCH] Refactor Input Transforms Summary: Currently, we apply the input transforms in `train` mode at the `forward` call, and in `eval` model at the `posterior` call. We also use a `transform_train_inputs` call at the `eval/train` calls to make sure that at `eval` time the `train_inputs` are stored as transformed (since they don't pass through `posterior`). This design supports `ExactGP` models, and supports specifying where to apply which input transform via the flags (so that one-to-many transforms are only applied to test inputs). However, this does not work great with Approximate GP models, since this setup does not transform the inducing points at `eval` time. This refactor splits out one-to-many transforms as `InputAugmentationTransform`, allowing us to revert to simply applying the `transform_inputs` in the `forward` pass (at all times). We still need to apply one-to-many transforms (now called `InputAugmentationTransform`) in `posterior`, so we introduce an `augment_inputs` method. (Inspired by the public-private APIs of Ax) In order to minimize the transform related knowledge expected from developers, this introduces a `Model.forward` call that applies `transform_inputs` and calls `self._forward`. `._forward` is the usual `forward` call that computes the prior, except that it no longer has to worry about transforms. Similarly, for the `posterior`, this makes `Model.posterior` into a simple wrapper around `Model._posterior`, which applies the `augment_inputs` call and the `posterior_transform`. Again, the `._posterior` becomes the usual posterior call that no longer has to worry about the input or posterior transforms (still has to deal with the outcome transform in the current implementation, though we can fix this by bringing back the `fantasize` flag). This diff presents a minimal implementation around the `SingleTaskGP` model. Differential Revision: D35129407 fbshipit-source-id: 0a8ab840774bcd281f50925314d04725b453a7c8 --- .../multi_objective/monte_carlo.py | 2 +- botorch/models/gp_regression.py | 51 +-- botorch/models/gpytorch.py | 29 +- botorch/models/model.py | 154 ++++---- botorch/models/model_list_gp_regression.py | 10 - botorch/models/transforms/input.py | 366 +----------------- .../models/transforms/input_augmentation.py | 239 ++++++++++++ botorch/models/utils.py | 6 - 8 files changed, 369 insertions(+), 488 deletions(-) create mode 100644 botorch/models/transforms/input_augmentation.py diff --git a/botorch/acquisition/multi_objective/monte_carlo.py b/botorch/acquisition/multi_objective/monte_carlo.py index f96e34771e..df45143446 100644 --- a/botorch/acquisition/multi_objective/monte_carlo.py +++ b/botorch/acquisition/multi_objective/monte_carlo.py @@ -42,7 +42,7 @@ from botorch.exceptions.errors import UnsupportedError from botorch.exceptions.warnings import BotorchWarning from botorch.models.model import Model -from botorch.models.transforms.input import InputPerturbation +from botorch.models.transforms.input_augmentation import InputPerturbation from botorch.posteriors import DeterministicPosterior from botorch.posteriors.posterior import Posterior from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler diff --git a/botorch/models/gp_regression.py b/botorch/models/gp_regression.py index a637feb4c6..4c9060aa70 100644 --- a/botorch/models/gp_regression.py +++ b/botorch/models/gp_regression.py @@ -16,8 +16,9 @@ from botorch import settings from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel from botorch.models.transforms.input import InputTransform +from botorch.models.transforms.input_augmentation import InputAugmentationTransform from botorch.models.transforms.outcome import Log, OutcomeTransform -from botorch.models.utils import fantasize as fantasize_flag, validate_input_scaling +from botorch.models.utils import validate_input_scaling from botorch.sampling.samplers import MCSampler from botorch.utils.containers import TrainingData from gpytorch.constraints.constraints import GreaterThan @@ -70,6 +71,7 @@ def __init__( mean_module: Optional[Mean] = None, outcome_transform: Optional[OutcomeTransform] = None, input_transform: Optional[InputTransform] = None, + input_augmentation_transform: Optional[InputAugmentationTransform] = None, ) -> None: r"""A single-task exact GP model. @@ -88,6 +90,8 @@ def __init__( `.posterior` on the model will be on the original scale). input_transform: An input transform that is applied in the model's forward pass. + input_augmentation_transform: An input augmentation transform that is + applied in the `posterior` call. Example: >>> train_X = torch.rand(20, 2) @@ -148,11 +152,11 @@ def __init__( self.outcome_transform = outcome_transform if input_transform is not None: self.input_transform = input_transform + if input_augmentation_transform is not None: + self.input_augmentation_transform = input_augmentation_transform self.to(train_X) - def forward(self, x: Tensor) -> MultivariateNormal: - if self.training: - x = self.transform_inputs(x) + def _forward(self, x: Tensor) -> MultivariateNormal: mean_x = self.mean_module(x) covar_x = self.covar_module(x) return MultivariateNormal(mean_x, covar_x) @@ -191,6 +195,7 @@ def __init__( mean_module: Optional[Mean] = None, outcome_transform: Optional[OutcomeTransform] = None, input_transform: Optional[InputTransform] = None, + input_augmentation_transform: Optional[InputAugmentationTransform] = None, **kwargs: Any, ) -> None: r"""A single-task exact GP model using fixed noise levels. @@ -210,6 +215,8 @@ def __init__( `.posterior` on the model will be on the original scale). input_transform: An input transfrom that is applied in the model's forward pass. + input_augmentation_transform: An input augmentation transform that is + applied in the `posterior` call. Example: >>> train_X = torch.rand(20, 2) @@ -262,7 +269,8 @@ def __init__( self.input_transform = input_transform if outcome_transform is not None: self.outcome_transform = outcome_transform - + if input_augmentation_transform is not None: + self.input_augmentation_transform = input_augmentation_transform self.to(train_X) def fantasize( @@ -298,24 +306,19 @@ def fantasize( The constructed fantasy model. """ propagate_grads = kwargs.pop("propagate_grads", False) - with fantasize_flag(): - with settings.propagate_grads(propagate_grads): - post_X = self.posterior( - X, observation_noise=observation_noise, **kwargs - ) - Y_fantasized = sampler(post_X) # num_fantasies x batch_shape x n' x m - # Use the mean of the previous noise values (TODO: be smarter here). - # noise should be batch_shape x q x m when X is batch_shape x q x d, and - # Y_fantasized is num_fantasies x batch_shape x q x m. - noise_shape = Y_fantasized.shape[1:] - noise = self.likelihood.noise.mean().expand(noise_shape) - return self.condition_on_observations( - X=self.transform_inputs(X), Y=Y_fantasized, noise=noise - ) + with settings.propagate_grads(propagate_grads): + post_X = self._posterior(X, observation_noise=observation_noise, **kwargs) + Y_fantasized = sampler(post_X) # num_fantasies x batch_shape x n' x m + # Use the mean of the previous noise values (TODO: be smarter here). + # noise should be batch_shape x q x m when X is batch_shape x q x d, and + # Y_fantasized is num_fantasies x batch_shape x q x m. + noise_shape = Y_fantasized.shape[1:] + noise = self.likelihood.noise.mean().expand(noise_shape) + return self.condition_on_observations( + X=self.transform_inputs(X), Y=Y_fantasized, noise=noise + ) - def forward(self, x: Tensor) -> MultivariateNormal: - if self.training: - x = self.transform_inputs(x) + def _forward(self, x: Tensor) -> MultivariateNormal: mean_x = self.mean_module(x) covar_x = self.covar_module(x) return MultivariateNormal(mean_x, covar_x) @@ -370,6 +373,7 @@ def __init__( train_Yvar: Tensor, outcome_transform: Optional[OutcomeTransform] = None, input_transform: Optional[InputTransform] = None, + input_augmentation_transform: Optional[InputAugmentationTransform] = None, ) -> None: r"""A single-task exact GP model using a heteroskedastic noise model. @@ -386,6 +390,8 @@ def __init__( variances, which will happen after this transform is applied. input_transform: An input transfrom that is applied in the model's forward pass. + input_augmentation_transform: An input augmentation transform that is + applied in the `posterior` call. Example: >>> train_X = torch.rand(20, 2) @@ -419,6 +425,7 @@ def __init__( train_Y=train_Y, likelihood=likelihood, input_transform=input_transform, + input_augmentation_transform=input_augmentation_transform, ) self.register_added_loss_term("noise_added_loss") self.update_added_loss_term( diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index e9577b05c5..1191612758 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -117,11 +117,10 @@ def num_outputs(self) -> int: r"""The number of outputs of the model.""" return self._num_outputs - def posterior( + def _posterior( self, X: Tensor, observation_noise: Union[bool, Tensor] = False, - posterior_transform: Optional[PosteriorTransform] = None, **kwargs: Any, ) -> GPyTorchPosterior: r"""Computes the posterior over model outputs at the provided points. @@ -133,7 +132,6 @@ def posterior( observation_noise: If True, add the observation noise from the likelihood to the posterior. If a Tensor, use it directly as the observation noise (must be of shape `(batch_shape) x q`). - posterior_transform: An optional PosteriorTransform. Returns: A `GPyTorchPosterior` object, representing a batch of `b` joint @@ -141,9 +139,6 @@ def posterior( specified. """ self.eval() # make sure model is in eval mode - # input transforms are applied at `posterior` in `eval` mode, and at - # `model.forward()` at the training time - X = self.transform_inputs(X) with gpt_posterior_settings(): mvn = self(X) if observation_noise is not False: @@ -158,8 +153,6 @@ def posterior( posterior = GPyTorchPosterior(mvn=mvn) if hasattr(self, "outcome_transform"): posterior = self.outcome_transform.untransform_posterior(posterior) - if posterior_transform is not None: - return posterior_transform(posterior) return posterior def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model: @@ -301,12 +294,11 @@ def _transform_tensor_args( ) return X, Y.squeeze(-1), None if Yvar is None else Yvar.squeeze(-1) - def posterior( + def _posterior( self, X: Tensor, output_indices: Optional[List[int]] = None, observation_noise: Union[bool, Tensor] = False, - posterior_transform: Optional[PosteriorTransform] = None, **kwargs: Any, ) -> GPyTorchPosterior: r"""Computes the posterior over model outputs at the provided points. @@ -323,7 +315,6 @@ def posterior( observation_noise: If True, add the observation noise from the likelihood to the posterior. If a Tensor, use it directly as the observation noise (must be of shape `(batch_shape) x q x m`). - posterior_transform: An optional PosteriorTransform. Returns: A `GPyTorchPosterior` object, representing `batch_shape` joint @@ -331,9 +322,6 @@ def posterior( `output_indices` each. Includes observation noise if specified. """ self.eval() # make sure model is in eval mode - # input transforms are applied at `posterior` in `eval` mode, and at - # `model.forward()` at the training time - X = self.transform_inputs(X) with gpt_posterior_settings(): # insert a dimension for the output dimension if self._num_outputs > 1: @@ -369,8 +357,6 @@ def posterior( posterior = GPyTorchPosterior(mvn=mvn) if hasattr(self, "outcome_transform"): posterior = self.outcome_transform.untransform_posterior(posterior) - if posterior_transform is not None: - return posterior_transform(posterior) return posterior def condition_on_observations( @@ -549,6 +535,8 @@ def posterior( by `output_indices` each. Includes measurement noise if `observation_noise` is specified. """ + # TODO: Not sure if this needs special handling or is good with a `_`. + # Leaving untouched for now. self.eval() # make sure model is in eval mode # input transforms are applied at `posterior` in `eval` mode, and at # `model.forward()` at the training time @@ -622,12 +610,11 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC): "long-format" multi-task GP in the style of `MultiTaskGP`. """ - def posterior( + def _posterior( self, X: Tensor, output_indices: Optional[List[int]] = None, observation_noise: Union[bool, Tensor] = False, - posterior_transform: Optional[PosteriorTransform] = None, **kwargs: Any, ) -> GPyTorchPosterior: r"""Computes the posterior over model outputs at the provided points. @@ -644,7 +631,6 @@ def posterior( observation_noise: If True, add observation noise from the respective likelihoods. If a Tensor, specifies the observation noise levels to add. - posterior_transform: An optional PosteriorTransform. Returns: A `GPyTorchPosterior` object, representing `batch_shape` joint @@ -663,9 +649,6 @@ def posterior( X_full = _make_X_full(X=X, output_indices=output_indices, tf=self._task_feature) self.eval() # make sure model is in eval mode - # input transforms are applied at `posterior` in `eval` mode, and at - # `model.forward()` at the training time - X_full = self.transform_inputs(X_full) with gpt_posterior_settings(): mvn = self(X_full) if observation_noise is not False: @@ -685,6 +668,4 @@ def posterior( posterior = GPyTorchPosterior(mvn=mtmvn) if hasattr(self, "outcome_transform"): posterior = self.outcome_transform.untransform_posterior(posterior) - if posterior_transform is not None: - return posterior_transform(posterior) return posterior diff --git a/botorch/models/model.py b/botorch/models/model.py index cf27ec4864..748c513930 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -10,7 +10,6 @@ from __future__ import annotations -import warnings from abc import ABC, abstractmethod from collections import defaultdict from copy import deepcopy @@ -19,32 +18,48 @@ import numpy as np import torch from botorch import settings -from botorch.models.utils import fantasize as fantasize_flag from botorch.posteriors import Posterior, PosteriorList from botorch.posteriors.fully_bayesian import FullyBayesianPosteriorList from botorch.sampling.samplers import MCSampler from botorch.utils.containers import TrainingData from botorch.utils.transforms import is_fully_bayesian +from gpytorch.distributions import MultivariateNormal from torch import Tensor from torch.nn import Module, ModuleList class Model(Module, ABC): - r"""Abstract base class for BoTorch models. + r"""Abstract base class for BoTorch models.""" - Args: - _has_transformed_inputs: A boolean denoting whether `train_inputs` are currently - stored as transformed or not. - _original_train_inputs: A Tensor storing the original train inputs for use in - `_revert_to_original_inputs`. Note that this is necessary since - transform / untransform cycle introduces numerical errors which lead - to upstream errors during training. - """ + def forward(self, x: Tensor) -> MultivariateNormal: + r"""Transforms the inputs and computes the prior over model outputs + at the provided points. - _has_transformed_inputs: bool = False - _original_train_inputs: Optional[Tensor] = None + Args: + x: A `b x q x d`-dim Tensor of inputs, where `d` is the dimension of + the feature space, `q` is the number of points considered jointly, + and `b` is the batch dimension. + + Returns: + A MultivariateNormal object denoting the prior distribution. + """ + x = self.transform_inputs(x) + return self._forward(x) @abstractmethod + def _forward(self, x: Tensor) -> MultivariateNormal: + r"""Computes the prior over model outputs at the provided points. + + Args: + x: A `b x q x d`-dim Tensor of inputs, where `d` is the dimension of + the feature space, `q` is the number of points considered jointly, + and `b` is the batch dimension. + + Returns: + A MultivariateNormal object denoting the prior distribution. + """ + pass # pragma: no cover + def posterior( self, X: Tensor, @@ -53,11 +68,8 @@ def posterior( posterior_transform: Optional[Callable[[Posterior], Posterior]] = None, **kwargs: Any, ) -> Posterior: - r"""Computes the posterior over model outputs at the provided points. - - Note: The input transforms should be applied here using - `self.transform_inputs(X)` after the `self.eval()` call and before - any `model.forward` or `model.likelihood` calls. + r"""Augments the inputs, if needed, and computes the posterior over model + outputs at the provided points. Args: X: A `b x q x d`-dim Tensor, where `d` is the dimension of the @@ -71,6 +83,42 @@ def posterior( observation_noise: If True, add observation noise to the posterior. posterior_transform: An optional PosteriorTransform. + Returns: + A `Posterior` object, representing a batch of `b` joint distributions + over `q` points and `m` outputs each. + """ + X = self.augment_inputs(X) + posterior = self._posterior( + X=X, + output_indices=output_indices, + observation_noise=observation_noise, + kwargs=kwargs, + ) + if posterior_transform is not None: + posterior = posterior_transform(posterior) + return posterior + + @abstractmethod + def _posterior( + self, + X: Tensor, + output_indices: Optional[List[int]] = None, + observation_noise: bool = False, + **kwargs: Any, + ) -> Posterior: + r"""Computes the posterior over model outputs at the provided points. + + Args: + X: A `b x q x d`-dim Tensor, where `d` is the dimension of the + feature space, `q` is the number of points considered jointly, + and `b` is the batch dimension. + output_indices: A list of indices, corresponding to the outputs over + which to compute the posterior (if the model is multi-output). + Can be used to speed up computation if only a subset of the + model's outputs are required for optimization. If omitted, + computes the posterior over all model outputs. + observation_noise: If True, add observation noise to the posterior. + Returns: A `Posterior` object, representing a batch of `b` joint distributions over `q` points and `m` outputs each. @@ -161,13 +209,12 @@ def fantasize( The constructed fantasy model. """ propagate_grads = kwargs.pop("propagate_grads", False) - with fantasize_flag(): - with settings.propagate_grads(propagate_grads): - post_X = self.posterior(X, observation_noise=observation_noise) - Y_fantasized = sampler(post_X) # num_fantasies x batch_shape x n' x m - return self.condition_on_observations( - X=self.transform_inputs(X), Y=Y_fantasized, **kwargs - ) + with settings.propagate_grads(propagate_grads): + post_X = self._posterior(X, observation_noise=observation_noise) + Y_fantasized = sampler(post_X) # num_fantasies x batch_shape x n' x m + return self.condition_on_observations( + X=self.transform_inputs(X), Y=Y_fantasized, **kwargs + ) @classmethod def construct_inputs( @@ -186,11 +233,11 @@ def transform_inputs( r"""Transform inputs. Args: - X: A tensor of inputs + X: A `b x q x d`-dim tensor of inputs. input_transform: A Module that performs the input transformation. Returns: - A tensor of transformed inputs + A `b x q x d`-dim tensor of transformed inputs. """ if input_transform is not None: input_transform.to(X) @@ -200,49 +247,22 @@ def transform_inputs( except AttributeError: return X - def _set_transformed_inputs(self) -> None: - r"""Update training inputs with transformed inputs.""" - if hasattr(self, "input_transform") and not self._has_transformed_inputs: - if hasattr(self, "train_inputs"): - self._original_train_inputs = self.train_inputs[0] - with torch.no_grad(): - X_tf = self.input_transform.preprocess_transform( - self.train_inputs[0] - ) - self.set_train_data(X_tf, strict=False) - self._has_transformed_inputs = True - else: - warnings.warn( - "Could not update `train_inputs` with transformed inputs " - f"since {self.__class__.__name__} does not have a `train_inputs` " - "attribute. Make sure that the `input_transform` is applied to " - "both the train inputs and test inputs.", - RuntimeWarning, - ) - - def _revert_to_original_inputs(self) -> None: - r"""Revert training inputs back to original.""" - if hasattr(self, "input_transform") and self._has_transformed_inputs: - self.set_train_data(self._original_train_inputs, strict=False) - self._has_transformed_inputs = False - - def eval(self) -> Model: - r"""Puts the model in `eval` mode and sets the transformed inputs.""" - self._set_transformed_inputs() - return super().eval() - - def train(self, mode: bool = True) -> Model: - r"""Puts the model in `train` mode and reverts to the original inputs. + def augment_inputs( + self, + X: Tensor, + ) -> Tensor: + r"""Applies the input augmentation transform, if any. Args: - mode: A boolean denoting whether to put in `train` or `eval` mode. - If `False`, model is put in `eval` mode. + X: A `b x q x d`-dim tensor of inputs. + + Returns: + A `b x q' x d'`-dim tensor of augmented inputs. """ - if mode: - self._revert_to_original_inputs() - else: - self._set_transformed_inputs() - return super().train(mode=mode) + try: + return self.input_augmentation_transform(X) + except AttributeError: + return X class ModelList(Model): diff --git a/botorch/models/model_list_gp_regression.py b/botorch/models/model_list_gp_regression.py index c7d40700fb..d37d81fcea 100644 --- a/botorch/models/model_list_gp_regression.py +++ b/botorch/models/model_list_gp_regression.py @@ -102,13 +102,3 @@ def subset_output(self, idcs: List[int]) -> ModelListGP: The current model, subset to the specified output indices. """ return self.__class__(*[deepcopy(self.models[i]) for i in idcs]) - - def _set_transformed_inputs(self) -> None: - r"""Update training inputs with transformed inputs.""" - for m in self.models: - m._set_transformed_inputs() - - def _revert_to_original_inputs(self) -> None: - r"""Revert training inputs back to original.""" - for m in self.models: - m._revert_to_original_inputs() diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 40a10bc30c..97bfcd7b59 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -18,12 +18,11 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import torch from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.models.transforms.utils import expand_and_copy_tensor -from botorch.models.utils import fantasize from botorch.utils.rounding import approximate_round from gpytorch import Module as GPyTorchModule from gpytorch.constraints import GreaterThan @@ -39,20 +38,8 @@ class InputTransform(ABC): Note: Input transforms must inherit from `torch.nn.Module`. This is deferred to the subclasses to avoid any potential conflict between `gpytorch.module.Module` and `torch.nn.Module` in `Warp`. - - Properties: - transform_on_train: A boolean indicating whether to apply the - transform in train() mode. - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. - transform_on_fantasize: A boolean indicating whether to apply - the transform when called from within a `fantasize` call. """ - transform_on_eval: bool - transform_on_train: bool - transform_on_fantasize: bool - def forward(self, X: Tensor) -> Tensor: r"""Transform the inputs to a model. @@ -60,15 +47,9 @@ def forward(self, X: Tensor) -> Tensor: X: A `batch_shape x n x d`-dim tensor of inputs. Returns: - A `batch_shape x n' x d`-dim tensor of transformed inputs. + A `batch_shape x n x d`-dim tensor of transformed inputs. """ - if self.training: - if self.transform_on_train: - return self.transform(X) - elif self.transform_on_eval: - if fantasize.off() or self.transform_on_fantasize: - return self.transform(X) - return X + return self.transform(X) @abstractmethod def transform(self, X: Tensor) -> Tensor: @@ -110,44 +91,11 @@ def equals(self, other: InputTransform) -> bool: A boolean indicating if the other transform is equivalent. """ other_state_dict = other.state_dict() - return ( - type(self) == type(other) - and (self.transform_on_train == other.transform_on_train) - and (self.transform_on_eval == other.transform_on_eval) - and (self.transform_on_fantasize == other.transform_on_fantasize) - and all( - torch.allclose(v, other_state_dict[k].to(v)) - for k, v in self.state_dict().items() - ) + return type(self) == type(other) and all( + torch.allclose(v, other_state_dict[k].to(v)) + for k, v in self.state_dict().items() ) - def preprocess_transform(self, X: Tensor) -> Tensor: - r"""Apply transforms for preprocessing inputs. - - The main use cases for this method are 1) to preprocess training data - before calling `set_train_data` and 2) preprocess `X_baseline` for noisy - acquisition functions so that `X_baseline` is "preprocessed" with the - same transformations as the cached training inputs. - - Args: - X: A `batch_shape x n x d`-dim tensor of inputs. - - Returns: - A `batch_shape x n x d`-dim tensor of (transformed) inputs. - """ - if self.transform_on_train: - # We need to disable learning of bounds here. - # See why: https://github.com/pytorch/botorch/issues/1078. - if hasattr(self, "learn_bounds"): - learn_bounds = self.learn_bounds - self.learn_bounds = False - result = self.transform(X) - self.learn_bounds = learn_bounds - return result - else: - return self.transform(X) - return X - class ChainedInputTransform(InputTransform, ModuleDict): r"""An input transform representing the chaining of individual transforms.""" @@ -171,13 +119,6 @@ def __init__(self, **transforms: InputTransform) -> None: """ super().__init__(OrderedDict(transforms)) - self.transform_on_train = False - self.transform_on_eval = False - self.transform_on_fantasize = False - for tf in transforms.values(): - self.transform_on_train |= tf.transform_on_train - self.transform_on_eval |= tf.transform_on_eval - self.transform_on_fantasize |= tf.transform_on_fantasize def transform(self, X: Tensor) -> Tensor: r"""Transform the inputs to a model. @@ -222,24 +163,6 @@ def equals(self, other: InputTransform) -> bool: t1 == t2 for t1, t2 in zip(self.values(), other.values()) ) - def preprocess_transform(self, X: Tensor) -> Tensor: - r"""Apply transforms for preprocessing inputs. - - The main use cases for this method are 1) to preprocess training data - before calling `set_train_data` and 2) preprocess `X_baseline` for noisy - acquisition functions so that `X_baseline` is "preprocessed" with the - same transformations as the cached training inputs. - - Args: - X: A `batch_shape x n x d`-dim tensor of inputs. - - Returns: - A `batch_shape x n x d`-dim tensor of (transformed) inputs. - """ - for tf in self.values(): - X = tf.preprocess_transform(X) - return X - class ReversibleInputTransform(InputTransform, ABC): r"""An abstract class for a reversible input transform. @@ -323,9 +246,6 @@ def __init__( indices: Optional[List[int]] = None, bounds: Optional[Tensor] = None, batch_shape: torch.Size = torch.Size(), # noqa: B008 - transform_on_train: bool = True, - transform_on_eval: bool = True, - transform_on_fantasize: bool = True, reverse: bool = False, min_range: float = 1e-8, ) -> None: @@ -340,12 +260,6 @@ def __init__( batch_shape: The batch shape of the inputs (asssuming input tensors of shape `batch_shape x n x d`). If provided, perform individual normalization per batch, otherwise uses a single normalization. - transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True. - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True. - transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: True. reverse: A boolean indicating whether the forward pass should untransform the inputs. min_range: Amount of noise to add to the range to ensure no division by @@ -378,9 +292,6 @@ def __init__( self.register_buffer("mins", mins) self.register_buffer("ranges", ranges) self._d = d - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize self.reverse = reverse self.batch_shape = batch_shape self.min_range = min_range @@ -480,9 +391,6 @@ def __init__( d: int, indices: Optional[List[int]] = None, batch_shape: torch.Size = torch.Size(), # noqa: B008 - transform_on_train: bool = True, - transform_on_eval: bool = True, - transform_on_fantasize: bool = True, reverse: bool = False, min_std: float = 1e-8, ) -> None: @@ -495,10 +403,6 @@ def __init__( batch_shape: The batch shape of the inputs (asssuming input tensors of shape `batch_shape x n x d`). If provided, perform individual normalization per batch, otherwise uses a single normalization. - transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True reverse: A boolean indicating whether the forward pass should untransform the inputs. min_std: Amount of noise to add to the standard deviation to ensure no @@ -519,9 +423,6 @@ def __init__( self.register_buffer("means", torch.zeros(*batch_shape, 1, d)) self.register_buffer("stds", torch.ones(*batch_shape, 1, d)) self._d = d - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize self.batch_shape = batch_shape self.min_std = min_std self.reverse = reverse @@ -639,9 +540,6 @@ class Round(InputTransform, Module): def __init__( self, indices: List[int], - transform_on_train: bool = True, - transform_on_eval: bool = True, - transform_on_fantasize: bool = True, approximate: bool = True, tau: float = 1e-3, ) -> None: @@ -649,20 +547,11 @@ def __init__( Args: indices: The indices of the integer inputs. - transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True. - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True. - transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: True. approximate: A boolean indicating whether approximate or exact rounding should be used. Default: approximate. tau: The temperature parameter for approximate rounding. """ super().__init__() - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize self.register_buffer("indices", torch.tensor(indices, dtype=torch.long)) self.approximate = approximate self.tau = tau @@ -707,29 +596,17 @@ class Log10(ReversibleInputTransform, Module): def __init__( self, indices: List[int], - transform_on_train: bool = True, - transform_on_eval: bool = True, - transform_on_fantasize: bool = True, reverse: bool = False, ) -> None: r"""Initialize transform. Args: indices: The indices of the inputs to log transform. - transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True. - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True. - transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: True. reverse: A boolean indicating whether the forward pass should untransform the inputs. """ super().__init__() self.register_buffer("indices", torch.tensor(indices, dtype=torch.long)) - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize self.reverse = reverse def _transform(self, X: Tensor) -> Tensor: @@ -780,9 +657,6 @@ class Warp(ReversibleInputTransform, GPyTorchModule): def __init__( self, indices: List[int], - transform_on_train: bool = True, - transform_on_eval: bool = True, - transform_on_fantasize: bool = True, reverse: bool = False, eps: float = 1e-7, concentration1_prior: Optional[Prior] = None, @@ -793,12 +667,6 @@ def __init__( Args: indices: The indices of the inputs to warp. - transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True. - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True. - transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: True. reverse: A boolean indicating whether the forward pass should untransform the inputs. eps: A small value used to clip values to be in the interval (0, 1). @@ -810,9 +678,6 @@ def __init__( """ super().__init__() self.register_buffer("indices", torch.tensor(indices, dtype=torch.long)) - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize self.reverse = reverse self.batch_shape = batch_shape or torch.Size([]) self._X_min = eps @@ -913,105 +778,6 @@ def _untransform(self, X: Tensor) -> Tensor: return X_tf -class AppendFeatures(InputTransform, Module): - r"""A transform that appends the input with a given set of features. - - As an example, this can be used with `RiskMeasureMCObjective` to optimize risk - measures as described in [Cakmak2020risk]_. A tutorial notebook implementing the - rhoKG acqusition function introduced in [Cakmak2020risk]_ can be found at - https://botorch.org/tutorials/risk_averse_bo_with_environmental_variables. - - The steps for using this to obtain samples of a risk measure are as follows: - - - Train a model on `(x, w)` inputs and the corresponding observations; - - - Pass in an instance of `AppendFeatures` with the `feature_set` denoting the - samples of `W` as the `input_transform` to the trained model; - - - Call `posterior(...).rsample(...)` on the model with `x` inputs only to - get the joint posterior samples over `(x, w)`s, where the `w`s come - from the `feature_set`; - - - Pass these posterior samples through the `RiskMeasureMCObjective` of choice to - get the samples of the risk measure. - - Note: The samples of the risk measure obtained this way are in general biased - since the `feature_set` does not fully represent the distribution of the - environmental variable. - - Example: - >>> # We consider 1D `x` and 1D `w`, with `W` having a - >>> # uniform distribution over [0, 1] - >>> model = SingleTaskGP( - ... train_X=torch.rand(10, 2), - ... train_Y=torch.randn(10, 1), - ... input_transform=AppendFeatures(feature_set=torch.rand(10, 1)) - ... ) - >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) - >>> fit_gpytorch_model(mll) - >>> test_x = torch.rand(3, 1) - >>> # `posterior_samples` is a `10 x 30 x 1`-dim tensor - >>> posterior_samples = model.posterior(test_x).rsamples(torch.size([10])) - >>> risk_measure = VaR(alpha=0.8, n_w=10) - >>> # `risk_measure_samples` is a `10 x 3`-dim tensor of samples of the - >>> # risk measure VaR - >>> risk_measure_samples = risk_measure(posterior_samples) - """ - - def __init__( - self, - feature_set: Tensor, - transform_on_train: bool = False, - transform_on_eval: bool = True, - transform_on_fantasize: bool = False, - ) -> None: - r"""Append `feature_set` to each input. - - Args: - feature_set: An `n_f x d_f`-dim tensor denoting the features to be - appended to the inputs. - transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: False. - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True. - transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: False. - """ - super().__init__() - if feature_set.dim() != 2: - raise ValueError("`feature_set` must be an `n_f x d_f`-dim tensor!") - self.register_buffer("feature_set", feature_set) - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize - - def transform(self, X: Tensor) -> Tensor: - r"""Transform the inputs by appending `feature_set` to each input. - - For each `1 x d`-dim element in the input tensor, this will produce - an `n_f x (d + d_f)`-dim tensor with `feature_set` appended as the last `d_f` - dimensions. For a generic `batch_shape x q x d`-dim `X`, this translates to a - `batch_shape x (q * n_f) x (d + d_f)`-dim output, where the values corresponding - to `X[..., i, :]` are found in `output[..., i * n_f: (i + 1) * n_f, :]`. - - Note: Adding the `feature_set` on the `q-batch` dimension is necessary to avoid - introducing additional bias by evaluating the inputs on independent GP - sample paths. - - Args: - X: A `batch_shape x q x d`-dim tensor of inputs. - - Returns: - A `batch_shape x (q * n_f) x (d + d_f)`-dim tensor of appended inputs. - """ - expanded_X = X.unsqueeze(dim=-2).expand( - *X.shape[:-1], self.feature_set.shape[0], -1 - ) - expanded_features = self.feature_set.expand(*expanded_X.shape[:-1], -1) - appended_X = torch.cat([expanded_X, expanded_features], dim=-1) - return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1]) - - class FilterFeatures(InputTransform, Module): r"""A transform that filters the input with a given set of features indices. @@ -1025,21 +791,12 @@ class FilterFeatures(InputTransform, Module): def __init__( self, feature_indices: Tensor, - transform_on_train: bool = True, - transform_on_eval: bool = True, - transform_on_fantasize: bool = True, ) -> None: r"""Filter features from a model. Args: feature_set: An one-dim tensor denoting the indices of the features to be kept and fed to the model. - transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True. - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True. - transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: True. """ super().__init__() if feature_indices.dim() != 1: @@ -1052,9 +809,6 @@ def __init__( ) if len(feature_indices.unique()) != len(feature_indices): raise ValueError("Elements of `feature_indices` tensor must be unique!") - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize self.register_buffer("feature_indices", feature_indices) def transform(self, X: Tensor) -> Tensor: @@ -1062,10 +816,10 @@ def transform(self, X: Tensor) -> Tensor: feature indices and filtering out the others. Args: - X: A `batch_shape x q x d`-dim tensor of inputs. + X: A `batch_shape x n x d`-dim tensor of inputs. Returns: - A `batch_shape x q x e`-dim tensor of filtered inputs, + A `batch_shape x n x e`-dim tensor of filtered inputs, where `e` is the length of `feature_indices`. """ return X[..., self.feature_indices] @@ -1082,107 +836,3 @@ def equals(self, other: InputTransform) -> bool: if len(self.feature_indices) != len(other.feature_indices): return False return super().equals(other=other) - - -class InputPerturbation(InputTransform, Module): - r"""A transform that adds the set of perturbations to the given input. - - Similar to `AppendFeatures`, this can be used with `RiskMeasureMCObjective` - to optimize risk measures. See `AppendFeatures` for additional discussion - on optimizing risk measures. - - A tutorial notebook using this with `qNoisyExpectedImprovement` can be found at - https://botorch.org/tutorials/risk_averse_bo_with_input_perturbations. - """ - - def __init__( - self, - perturbation_set: Union[Tensor, Callable[[Tensor], Tensor]], - bounds: Optional[Tensor] = None, - multiplicative: bool = False, - transform_on_train: bool = False, - transform_on_eval: bool = True, - transform_on_fantasize: bool = False, - ) -> None: - r"""Add `perturbation_set` to each input. - - Args: - perturbation_set: An `n_p x d`-dim tensor denoting the perturbations - to be added to the inputs. Alternatively, this can be a callable that - returns `batch x n_p x d`-dim tensor of perturbations for input of - shape `batch x d`. This is useful for heteroscedastic perturbations. - bounds: A `2 x d`-dim tensor of lower and upper bounds for each - column of the input. If given, the perturbed inputs will be - clamped to these bounds. - multiplicative: A boolean indicating whether the input perturbations - are additive or multiplicative. If True, inputs will be multiplied - with the perturbations. - transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: False. - transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True. - transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: False. - """ - super().__init__() - if isinstance(perturbation_set, Tensor): - if perturbation_set.dim() != 2: - raise ValueError("`perturbation_set` must be an `n_p x d`-dim tensor!") - self.register_buffer("perturbation_set", perturbation_set) - else: - self.perturbation_set = perturbation_set - if bounds is not None: - if ( - isinstance(perturbation_set, Tensor) - and bounds.shape[-1] != perturbation_set.shape[-1] - ): - raise ValueError( - "`bounds` must have the same number of columns (last dimension) as " - f"the `perturbation_set`! Got {bounds.shape[-1]} and " - f"{perturbation_set.shape[-1]}." - ) - self.register_buffer("bounds", bounds) - else: - self.bounds = None - self.multiplicative = multiplicative - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize - - def transform(self, X: Tensor) -> Tensor: - r"""Transform the inputs by adding `perturbation_set` to each input. - - For each `1 x d`-dim element in the input tensor, this will produce - an `n_p x d`-dim tensor with the `perturbation_set` added to the input. - For a generic `batch_shape x q x d`-dim `X`, this translates to a - `batch_shape x (q * n_p) x d`-dim output, where the values corresponding - to `X[..., i, :]` are found in `output[..., i * n_w: (i + 1) * n_w, :]`. - - Note: Adding the `perturbation_set` on the `q-batch` dimension is necessary - to avoid introducing additional bias by evaluating the inputs on independent - GP sample paths. - - Args: - X: A `batch_shape x q x d`-dim tensor of inputs. - - Returns: - A `batch_shape x (q * n_p) x d`-dim tensor of perturbed inputs. - """ - if isinstance(self.perturbation_set, Tensor): - perturbations = self.perturbation_set - else: - perturbations = self.perturbation_set(X) - expanded_X = X.unsqueeze(dim=-2).expand( - *X.shape[:-1], perturbations.shape[-2], -1 - ) - expanded_perturbations = perturbations.expand(*expanded_X.shape[:-1], -1) - if self.multiplicative: - perturbed_inputs = expanded_X * expanded_perturbations - else: - perturbed_inputs = expanded_X + expanded_perturbations - perturbed_inputs = perturbed_inputs.reshape(*X.shape[:-2], -1, X.shape[-1]) - if self.bounds is not None: - perturbed_inputs = torch.maximum( - torch.minimum(perturbed_inputs, self.bounds[1]), self.bounds[0] - ) - return perturbed_inputs diff --git a/botorch/models/transforms/input_augmentation.py b/botorch/models/transforms/input_augmentation.py new file mode 100644 index 0000000000..67d2a7f0d5 --- /dev/null +++ b/botorch/models/transforms/input_augmentation.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Input Augmentation Transformations. + +These classes implement a variety of transformations for +input parameters that are applied only to the test inputs +at the `posterior` call. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional, Callable, Union + +import torch +from torch import Tensor +from torch.nn import Module + + +class InputAugmentationTransform(Module, ABC): + r"""Abstract base class for input augmentation transforms.""" + + @abstractmethod + def forward(self, X: Tensor) -> Tensor: + r"""Transform the inputs to a model. + + Args: + X: A `batch_shape x q x d`-dim tensor of inputs. + + Returns: + A `batch_shape x q' x d'`-dim tensor of transformed inputs, where `q'` + is generally an integer multiple of `q` and `d' > d`, both determined + by the transform arguments. + """ + pass # pragma: no cover + + def equals(self, other: InputAugmentationTransform) -> bool: + r"""Check if another input augmentation transform is equivalent. + + Note: The reason that a custom equals method is defined rather than + defining an __eq__ method is because defining an __eq__ method sets + the __hash__ method to None. Hashing modules is currently used in + pytorch. See https://github.com/pytorch/pytorch/issues/7733. + + Args: + other: Another input augmentation transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + other_state_dict = other.state_dict() + return type(self) == type(other) and all( + torch.allclose(v, other_state_dict[k].to(v)) + for k, v in self.state_dict().items() + ) + + +class AppendFeatures(InputAugmentationTransform): + r"""A transform that appends the input with a given set of features. + + As an example, this can be used with `RiskMeasureMCObjective` to optimize risk + measures as described in [Cakmak2020risk]_. A tutorial notebook implementing the + rhoKG acqusition function introduced in [Cakmak2020risk]_ can be found at + https://botorch.org/tutorials/risk_averse_bo_with_environmental_variables. + + The steps for using this to obtain samples of a risk measure are as follows: + + - Train a model on `(x, w)` inputs and the corresponding observations; + + - Pass in an instance of `AppendFeatures` with the `feature_set` denoting the + samples of `W` as the `input_transform` to the trained model; + + - Call `posterior(...).rsample(...)` on the model with `x` inputs only to + get the joint posterior samples over `(x, w)`s, where the `w`s come + from the `feature_set`; + + - Pass these posterior samples through the `RiskMeasureMCObjective` of choice to + get the samples of the risk measure. + + Note: The samples of the risk measure obtained this way are in general biased + since the `feature_set` does not fully represent the distribution of the + environmental variable. + + Example: + >>> # We consider 1D `x` and 1D `w`, with `W` having a + >>> # uniform distribution over [0, 1] + >>> model = SingleTaskGP( + ... train_X=torch.rand(10, 2), + ... train_Y=torch.randn(10, 1), + ... input_augmentation_transform=AppendFeatures(feature_set=torch.rand(10, 1)) + ... ) + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) + >>> fit_gpytorch_model(mll) + >>> test_x = torch.rand(3, 1) + >>> # `posterior_samples` is a `10 x 30 x 1`-dim tensor + >>> posterior_samples = model.posterior(test_x).rsamples(torch.size([10])) + >>> risk_measure = VaR(alpha=0.8, n_w=10) + >>> # `risk_measure_samples` is a `10 x 3`-dim tensor of samples of the + >>> # risk measure VaR + >>> risk_measure_samples = risk_measure(posterior_samples) + """ + + def __init__( + self, + feature_set: Tensor, + ) -> None: + r"""Append `feature_set` to each input. + + Args: + feature_set: An `n_f x d_f`-dim tensor denoting the features to be + appended to the inputs. + """ + super().__init__() + if feature_set.dim() != 2: + raise ValueError("`feature_set` must be an `n_f x d_f`-dim tensor!") + self.register_buffer("feature_set", feature_set) + + def forward(self, X: Tensor) -> Tensor: + r"""Transform the inputs by appending `feature_set` to each input. + + For each `1 x d`-dim element in the input tensor, this will produce + an `n_f x (d + d_f)`-dim tensor with `feature_set` appended as the last `d_f` + dimensions. For a generic `batch_shape x q x d`-dim `X`, this translates to a + `batch_shape x (q * n_f) x (d + d_f)`-dim output, where the values corresponding + to `X[..., i, :]` are found in `output[..., i * n_f: (i + 1) * n_f, :]`. + + Note: Adding the `feature_set` on the `q-batch` dimension is necessary to avoid + introducing additional bias by evaluating the inputs on independent GP + sample paths. + + Args: + X: A `batch_shape x q x d`-dim tensor of inputs. + + Returns: + A `batch_shape x (q * n_f) x (d + d_f)`-dim tensor of appended inputs. + """ + expanded_X = X.unsqueeze(dim=-2).expand( + *X.shape[:-1], self.feature_set.shape[0], -1 + ) + expanded_features = self.feature_set.expand(*expanded_X.shape[:-1], -1) + appended_X = torch.cat([expanded_X, expanded_features], dim=-1) + return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1]) + + +class InputPerturbation(InputAugmentationTransform): + r"""A transform that adds the set of perturbations to the given input. + + Similar to `AppendFeatures`, this can be used with `RiskMeasureMCObjective` + to optimize risk measures. See `AppendFeatures` for additional discussion + on optimizing risk measures. + + A tutorial notebook using this with `qNoisyExpectedImprovement` can be found at + https://botorch.org/tutorials/risk_averse_bo_with_input_perturbations. + """ + + def __init__( + self, + perturbation_set: Union[Tensor, Callable[[Tensor], Tensor]], + bounds: Optional[Tensor] = None, + multiplicative: bool = False, + ) -> None: + r"""Add `perturbation_set` to each input. + + Args: + perturbation_set: An `n_p x d`-dim tensor denoting the perturbations + to be added to the inputs. Alternatively, this can be a callable that + returns `batch x n_p x d`-dim tensor of perturbations for input of + shape `batch x d`. This is useful for heteroscedastic perturbations. + bounds: A `2 x d`-dim tensor of lower and upper bounds for each + column of the input. If given, the perturbed inputs will be + clamped to these bounds. + multiplicative: A boolean indicating whether the input perturbations + are additive or multiplicative. If True, inputs will be multiplied + with the perturbations. + """ + super().__init__() + if isinstance(perturbation_set, Tensor): + if perturbation_set.dim() != 2: + raise ValueError("`perturbation_set` must be an `n_p x d`-dim tensor!") + self.register_buffer("perturbation_set", perturbation_set) + else: + self.perturbation_set = perturbation_set + if bounds is not None: + if ( + isinstance(perturbation_set, Tensor) + and bounds.shape[-1] != perturbation_set.shape[-1] + ): + raise ValueError( + "`bounds` must have the same number of columns (last dimension) as " + f"the `perturbation_set`! Got {bounds.shape[-1]} and " + f"{perturbation_set.shape[-1]}." + ) + self.register_buffer("bounds", bounds) + else: + self.bounds = None + self.multiplicative = multiplicative + + def forward(self, X: Tensor) -> Tensor: + r"""Transform the inputs by adding `perturbation_set` to each input. + + For each `1 x d`-dim element in the input tensor, this will produce + an `n_p x d`-dim tensor with the `perturbation_set` added to the input. + For a generic `batch_shape x q x d`-dim `X`, this translates to a + `batch_shape x (q * n_p) x d`-dim output, where the values corresponding + to `X[..., i, :]` are found in `output[..., i * n_w: (i + 1) * n_w, :]`. + + Note: Adding the `perturbation_set` on the `q-batch` dimension is necessary + to avoid introducing additional bias by evaluating the inputs on independent + GP sample paths. + + Args: + X: A `batch_shape x q x d`-dim tensor of inputs. + + Returns: + A `batch_shape x (q * n_p) x d`-dim tensor of perturbed inputs. + """ + if isinstance(self.perturbation_set, Tensor): + perturbations = self.perturbation_set + else: + perturbations = self.perturbation_set(X) + expanded_X = X.unsqueeze(dim=-2).expand( + *X.shape[:-1], perturbations.shape[-2], -1 + ) + expanded_perturbations = perturbations.expand(*expanded_X.shape[:-1], -1) + if self.multiplicative: + perturbed_inputs = expanded_X * expanded_perturbations + else: + perturbed_inputs = expanded_X + expanded_perturbations + perturbed_inputs = perturbed_inputs.reshape(*X.shape[:-2], -1, X.shape[-1]) + if self.bounds is not None: + perturbed_inputs = torch.maximum( + torch.minimum(perturbed_inputs, self.bounds[1]), self.bounds[0] + ) + return perturbed_inputs diff --git a/botorch/models/utils.py b/botorch/models/utils.py index 42b10d0fda..ff4ad69760 100644 --- a/botorch/models/utils.py +++ b/botorch/models/utils.py @@ -17,7 +17,6 @@ import torch from botorch import settings from botorch.exceptions import InputDataError, InputDataWarning -from botorch.settings import _Flag from gpytorch import settings as gpt_settings from gpytorch.module import Module from gpytorch.utils.broadcasting import _mul_broadcast_shape @@ -281,8 +280,3 @@ def gpt_posterior_settings(): gpt_settings.detach_test_caches(settings.propagate_grads.off()) ) yield - - -class fantasize(_Flag): - r"""A flag denoting whether we are currently in a `fantasize` context.""" - _state: bool = False