From 569ad1613b875d175a23407fdc23b5db3f7c9537 Mon Sep 17 00:00:00 2001 From: Khurram Ghani Date: Fri, 20 Sep 2024 13:17:39 +0100 Subject: [PATCH] Flatten add kernels dim into features --- .../basis_functions/fourier_features/base.py | 23 +++++++++++++++---- .../fourier_features/quadrature/gaussian.py | 2 +- .../fourier_features/random/base.py | 18 +++++++++++---- .../fourier_features/test_random.py | 6 ----- 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/gpflux/layers/basis_functions/fourier_features/base.py b/gpflux/layers/basis_functions/fourier_features/base.py index 73507ae5..c1a9c252 100644 --- a/gpflux/layers/basis_functions/fourier_features/base.py +++ b/gpflux/layers/basis_functions/fourier_features/base.py @@ -45,14 +45,17 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: M self.n_components = n_components if isinstance(kernel, gpflow.kernels.MultioutputKernel): self.is_batched = True + self.is_multioutput = True self.batch_size = kernel.num_latent_gps self.sub_kernels = kernel.latent_kernels elif isinstance(kernel, gpflow.kernels.Combination): self.is_batched = True + self.is_multioutput = False self.batch_size = len(kernel.kernels) self.sub_kernels = kernel.kernels else: self.is_batched = False + self.is_multioutput = False self.batch_size = 1 self.sub_kernels = [] @@ -68,7 +71,7 @@ def call(self, inputs: TensorType) -> tf.Tensor: :param inputs: The evaluation points, a tensor with the shape ``[N, D]``. - :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]'' in the multioutput case. + :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]'' in the batched case. """ if self.is_batched: X = [tf.divide(inputs, k.lengthscales) for k in self.sub_kernels] @@ -78,6 +81,13 @@ def call(self, inputs: TensorType) -> tf.Tensor: const = self._compute_constant() # [] or [P, 1, 1] bases = self._compute_bases(X) # [N, M] or [P, N, M] output = const * bases + + # For combination kernels, remove batch dimension and instead concatenate into the + # feature dimension. + if self.is_batched and not self.is_multioutput: + output = tf.transpose(output, perm=[1, 2, 0]) # [N, M, P] + output = tf.reshape(output, [output.shape[0], -1]) # [N, M*P] + tf.ensure_shape(output, self.compute_output_shape(inputs.shape)) return output @@ -90,12 +100,12 @@ def compute_output_shape(self, input_shape: ShapeType) -> tf.TensorShape: # TODO: Keras docs say "If the layer has not been built, this method # will call `build` on the layer." -- do we need to do so? tensor_shape = tf.TensorShape(input_shape).with_rank(2) - output_dim = self._compute_output_dim(input_shape) + output_dim = self.compute_output_dim(input_shape) trailing_shape = tensor_shape[:-1].concatenate(output_dim) - if self.is_batched: + if self.is_multioutput: return tf.TensorShape([self.batch_size]).concatenate(trailing_shape) # [P, N, M] else: - return trailing_shape # [N, M] + return trailing_shape # [N, M] or [N, M*P] def get_config(self) -> Mapping: """ @@ -115,7 +125,10 @@ def get_config(self) -> Mapping: return config @abstractmethod - def _compute_output_dim(self, input_shape: ShapeType) -> int: + def compute_output_dim(self, input_shape: ShapeType) -> int: + """ + Compute the output dimension of the layer. + """ pass @abstractmethod diff --git a/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py b/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py index 2391bb13..f24c543f 100644 --- a/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py +++ b/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py @@ -71,7 +71,7 @@ def build(self, input_shape: ShapeType) -> None: self.factors = tf.Variable(initial_value=omegas_value, trainable=False) # (M^D,) super(QuadratureFourierFeatures, self).build(input_shape) - def _compute_output_dim(self, input_shape: ShapeType) -> int: + def compute_output_dim(self, input_shape: ShapeType) -> int: input_dim = input_shape[-1] return 2 * self.n_components ** input_dim diff --git a/gpflux/layers/basis_functions/fourier_features/random/base.py b/gpflux/layers/basis_functions/fourier_features/random/base.py index bdfaa63c..4e877230 100644 --- a/gpflux/layers/basis_functions/fourier_features/random/base.py +++ b/gpflux/layers/basis_functions/fourier_features/random/base.py @@ -199,8 +199,13 @@ class RandomFourierFeatures(RandomFourierFeaturesBase): from phase-shifted cosines :class:`RandomFourierFeaturesCosine` :cite:p:`sutherland2015error`. """ - def _compute_output_dim(self, input_shape: ShapeType) -> int: - return 2 * self.n_components + def compute_output_dim(self, input_shape: ShapeType) -> int: + # For combination kernels, the number of features is multiplied by the number of + # sub-kernels. + dim = 2 * self.n_components + if self.is_batched and not self.is_multioutput: + dim *= self.batch_size + return dim def _compute_bases(self, inputs: TensorType) -> tf.Tensor: """ @@ -281,8 +286,13 @@ def _bias_build(self, n_components: int) -> None: def _bias_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType: return tf.random.uniform(shape=shape, maxval=2.0 * np.pi, dtype=dtype) - def _compute_output_dim(self, input_shape: ShapeType) -> int: - return self.n_components + def compute_output_dim(self, input_shape: ShapeType) -> int: + # For combination kernels, the number of features is multiplied by the number of + # sub-kernels. + dim = self.n_components + if self.is_batched and not self.is_multioutput: + dim *= self.batch_size + return dim def _compute_bases(self, inputs: TensorType) -> tf.Tensor: """ diff --git a/tests/gpflux/layers/basis_functions/fourier_features/test_random.py b/tests/gpflux/layers/basis_functions/fourier_features/test_random.py index ac1f38a1..6ab97f72 100644 --- a/tests/gpflux/layers/basis_functions/fourier_features/test_random.py +++ b/tests/gpflux/layers/basis_functions/fourier_features/test_random.py @@ -157,9 +157,6 @@ def test_multi_random_fourier_features_can_approximate_kernel_multidim( v = fourier_features(y) approx_kernel_matrix = u @ tf.linalg.matrix_transpose(v) - if isinstance(multi_kernel, gpflow.kernels.Sum): - approx_kernel_matrix = tf.reduce_sum(approx_kernel_matrix, axis=0) - if isinstance(multi_kernel, gpflow.kernels.MultioutputKernel): actual_kernel_matrix = multi_kernel.K(x, y, full_output_cov=False) else: @@ -231,9 +228,6 @@ def test_multi_random_fourier_feature_layer_compute_covariance_of_inducing_varia u = fourier_features(x_new) approx_kernel_matrix = u @ tf.linalg.matrix_transpose(u) - if isinstance(multi_kernel, gpflow.kernels.Sum): - approx_kernel_matrix = tf.reduce_sum(approx_kernel_matrix, axis=0) - if isinstance(multi_kernel, gpflow.kernels.MultioutputKernel): actual_kernel_matrix = multi_kernel.K(x_new, x_new, full_output_cov=False) else: